From 9c0b07602aecd30d54c0c03161ef990e4cbff986 Mon Sep 17 00:00:00 2001 From: Luis Date: Fri, 10 Apr 2026 18:25:46 -0400 Subject: [PATCH 01/58] chore: e2e tests update --- .../components/frontend/js/pages/data_catalog.js | 1 - .../components/frontend/js/pages/hygiene_issues.js | 7 ------- .../frontend/js/pages/monitors_dashboard.js | 2 -- .../frontend/js/pages/notification_settings.js | 1 - .../frontend/js/pages/profiling_results.js | 2 -- .../components/frontend/js/pages/profiling_runs.js | 3 --- .../components/frontend/js/pages/schedule_list.js | 2 +- .../components/frontend/js/pages/score_explorer.js | 10 +++++----- .../frontend/js/pages/table_group_list.js | 3 --- .../components/frontend/js/pages/test_results.js | 8 +------- .../ui/components/frontend/js/pages/test_runs.js | 4 ---- .../ui/components/frontend/js/pages/test_suites.js | 1 - .../frontend/js/shared/profiling_results_dialog.js | 3 +-- .../frontend/js/shared/source_data_dialog.js | 3 +-- testgen/ui/static/js/components/alert.js | 3 +-- testgen/ui/static/js/components/breadcrumbs.js | 7 ++----- testgen/ui/static/js/components/button.js | 3 +-- testgen/ui/static/js/components/caption.js | 2 +- testgen/ui/static/js/components/card.js | 3 +-- testgen/ui/static/js/components/checkbox.js | 3 +-- testgen/ui/static/js/components/code.js | 3 +-- testgen/ui/static/js/components/crontab_input.js | 4 ++-- testgen/ui/static/js/components/dialog.js | 12 ++++-------- testgen/ui/static/js/components/dropdown_button.js | 2 +- testgen/ui/static/js/components/empty_state.js | 3 +-- testgen/ui/static/js/components/expansion_panel.js | 3 +-- testgen/ui/static/js/components/file_input.js | 2 +- testgen/ui/static/js/components/frequency_bars.js | 1 + testgen/ui/static/js/components/input.js | 3 +-- testgen/ui/static/js/components/line_chart.js | 1 - testgen/ui/static/js/components/link.js | 3 +-- .../static/js/components/notification_settings.js | 1 - testgen/ui/static/js/components/paginator.js | 14 ++++++-------- testgen/ui/static/js/components/radio_group.js | 2 +- testgen/ui/static/js/components/schedule_list.js | 2 +- testgen/ui/static/js/components/score_breakdown.js | 4 ++-- testgen/ui/static/js/components/select.js | 5 ++--- testgen/ui/static/js/components/slider.js | 2 +- testgen/ui/static/js/components/spark_line.js | 4 ++-- testgen/ui/static/js/components/summary_bar.js | 1 + testgen/ui/static/js/components/summary_counts.js | 2 +- testgen/ui/static/js/components/table.js | 1 - .../ui/static/js/components/table_group_form.js | 4 ++-- testgen/ui/static/js/components/tabs.js | 10 ++++------ testgen/ui/static/js/components/textarea.js | 3 +-- testgen/ui/static/js/components/toggle.js | 2 +- testgen/ui/static/js/components/tree.js | 1 + 47 files changed, 56 insertions(+), 110 deletions(-) diff --git a/testgen/ui/components/frontend/js/pages/data_catalog.js b/testgen/ui/components/frontend/js/pages/data_catalog.js index 079a0ebe..3ed1c10c 100644 --- a/testgen/ui/components/frontend/js/pages/data_catalog.js +++ b/testgen/ui/components/frontend/js/pages/data_catalog.js @@ -220,7 +220,6 @@ const DataCatalog = (/** @type Properties */ props) => { value: getValue(props.table_group_filter_options)?.find((op) => op.selected)?.value ?? null, options: getValue(props.table_group_filter_options) ?? [], style: 'font-size: 14px;', - testId: 'table-group-filter', onChange: (value) => emit('TableGroupSelected', {payload: value}), }), div( diff --git a/testgen/ui/components/frontend/js/pages/hygiene_issues.js b/testgen/ui/components/frontend/js/pages/hygiene_issues.js index 4ee6810b..8f22132a 100644 --- a/testgen/ui/components/frontend/js/pages/hygiene_issues.js +++ b/testgen/ui/components/frontend/js/pages/hygiene_issues.js @@ -563,14 +563,12 @@ const HygieneIssues = (/** @type Properties */ props) => { profilingColumn: van.derive(() => getValue(props.profiling_column) ?? null), onClose: () => emit('ProfilingClosed', {}), width: '50rem', - testId: 'profiling-dialog', }), SourceDataDialog({ emit, sourceData: van.derive(() => getValue(props.source_data) ?? null), onClose: () => emit('SourceDataClosed', {}), renderHeader: HygieneSourceDataHeader, width: '60rem', - testId: 'source-data-dialog', }), // Summary row @@ -616,7 +614,6 @@ const HygieneIssues = (/** @type Properties */ props) => { label: 'Likelihood', value: likelihoodFilter.val, options: LIKELIHOOD_OPTIONS, - testId: 'likelihood-filter', style: 'min-width: 160px', onChange: onLikelihoodChange, allowNull: true, @@ -625,7 +622,6 @@ const HygieneIssues = (/** @type Properties */ props) => { label: 'Table', value: tableFilter.val, options: tableOptions.val, - testId: 'table-filter', style: 'min-width: 160px', filterable: true, onChange: onTableChange, @@ -635,7 +631,6 @@ const HygieneIssues = (/** @type Properties */ props) => { label: 'Column', value: columnFilter.val, options: columnOptions.val, - testId: 'column-filter', style: 'min-width: 160px', filterable: true, acceptNewOptions: true, @@ -646,7 +641,6 @@ const HygieneIssues = (/** @type Properties */ props) => { label: 'Issue Type', value: issueTypeFilter.val, options: issueTypeOptions.val, - testId: 'issue-type-filter', style: 'min-width: 200px', filterable: true, onChange: onIssueTypeChange, @@ -657,7 +651,6 @@ const HygieneIssues = (/** @type Properties */ props) => { label: 'Action', value: actionFilter.val, options: ACTION_OPTIONS, - testId: 'action-filter', style: 'min-width: 160px', onChange: onActionChange, allowNull: true, diff --git a/testgen/ui/components/frontend/js/pages/monitors_dashboard.js b/testgen/ui/components/frontend/js/pages/monitors_dashboard.js index 5809d7ec..f3eb9612 100644 --- a/testgen/ui/components/frontend/js/pages/monitors_dashboard.js +++ b/testgen/ui/components/frontend/js/pages/monitors_dashboard.js @@ -308,7 +308,6 @@ const MonitorsDashboard = (/** @type Properties */ props) => { })), allowNull: false, style: 'font-size: 14px;', - testId: 'table-group-filter', onChange: (value) => emit('SetParamValues', {payload: {table_group_id: value, table_name: null}}), }), () => getValue(props.has_monitor_test_suite) @@ -371,7 +370,6 @@ const MonitorsDashboard = (/** @type Properties */ props) => { width: 230, style: 'font-size: 14px;', icon: 'search', - testId: 'search-tables', value: tableNameFilterValue, onChange: (value, state) => emit('SetParamValues', {payload: {table_name_filter: value, current_page: 0}}), }), diff --git a/testgen/ui/components/frontend/js/pages/notification_settings.js b/testgen/ui/components/frontend/js/pages/notification_settings.js index 9b879fb7..27d6b598 100644 --- a/testgen/ui/components/frontend/js/pages/notification_settings.js +++ b/testgen/ui/components/frontend/js/pages/notification_settings.js @@ -248,7 +248,6 @@ const NotificationSettings = (/** @type Properties */ props) => { title: newNotificationItemForm.isEdit.val ? span({ class: 'notifications--editing' }, 'Edit Notification') : span({ class: 'text-green' }, 'Add Notification'), - testId: 'notification-item-editor', expanded: newNotificationItemForm.isEdit.val, }, div( diff --git a/testgen/ui/components/frontend/js/pages/profiling_results.js b/testgen/ui/components/frontend/js/pages/profiling_results.js index 10adb4eb..e7c4263d 100644 --- a/testgen/ui/components/frontend/js/pages/profiling_results.js +++ b/testgen/ui/components/frontend/js/pages/profiling_results.js @@ -253,7 +253,6 @@ const ProfilingResults = (/** @type Properties */ props) => { label: 'Table', value: tableFilter.val, options: tableOptions.val, - testId: 'table-filter', style: 'min-width: 200px', filterable: true, acceptNewOptions: true, @@ -264,7 +263,6 @@ const ProfilingResults = (/** @type Properties */ props) => { label: 'Column', value: columnFilter.val, options: columnOptions.val, - testId: 'column-filter', style: 'min-width: 200px', filterable: true, acceptNewOptions: true, diff --git a/testgen/ui/components/frontend/js/pages/profiling_runs.js b/testgen/ui/components/frontend/js/pages/profiling_runs.js index 74bfa329..736279ab 100644 --- a/testgen/ui/components/frontend/js/pages/profiling_runs.js +++ b/testgen/ui/components/frontend/js/pages/profiling_runs.js @@ -359,7 +359,6 @@ const Toolbar = ( options: getValue(props.table_group_options) ?? [], allowNull: true, style: 'font-size: 14px;', - testId: 'table-group-filter', onChange: (value) => emit('FilterApplied', { payload: { table_group_id: value } }), }), div( @@ -401,7 +400,6 @@ const Toolbar = ( tooltipPosition: 'left', style: 'background: var(--button-generic-background-color);', onclick: () => emit('RefreshData', {}), - testId: 'profiling-runs-refresh', }), ), ); @@ -426,7 +424,6 @@ const ProfilingRunItem = ( Checkbox({ checked: selected, onChange: (checked) => selected.val = checked, - testId: 'select-profiling-run', }), ) : '', diff --git a/testgen/ui/components/frontend/js/pages/schedule_list.js b/testgen/ui/components/frontend/js/pages/schedule_list.js index 1ced8e54..31559eb5 100644 --- a/testgen/ui/components/frontend/js/pages/schedule_list.js +++ b/testgen/ui/components/frontend/js/pages/schedule_list.js @@ -75,7 +75,7 @@ const ScheduleList = (/** @type Properties */ props) => { const content = div( { id: domId, class: 'flex-column fx-gap-2', style: 'height: 100%; overflow-y: auto;' }, ExpansionPanel( - {title: span({ class: 'text-green' }, 'Add Schedule'), testId: 'scheduler-cron-editor'}, + {title: span({ class: 'text-green' }, 'Add Schedule')}, div( { class: 'flex-row fx-gap-2' }, () => Select({ diff --git a/testgen/ui/components/frontend/js/pages/score_explorer.js b/testgen/ui/components/frontend/js/pages/score_explorer.js index f1744d34..fad49720 100644 --- a/testgen/ui/components/frontend/js/pages/score_explorer.js +++ b/testgen/ui/components/frontend/js/pages/score_explorer.js @@ -394,32 +394,32 @@ const Toolbar = ( div( { class: 'flex-row fx-gap-4 fx-flex-wrap' }, Checkbox({ + testId: 'include-total-score', label: 'Total Score', checked: displayTotalScore, - testId: 'include-total-score', onChange: (checked) => displayTotalScore.val = checked, }), Checkbox({ + testId: 'include-cde-score', label: 'CDE Score', checked: displayCDEScore, - testId: 'include-cde-score', onChange: (checked) => displayCDEScore.val = checked, }), div( { class: 'flex-row fx-gap-4' }, Checkbox({ + testId: 'include-category', label: 'Category:', checked: displayCategory, - testId: 'include-category', onChange: (checked) => displayCategory.val = checked, }), Select({ + testId: 'category-selector', style: 'margin-left: -8px;', height: 40, value: selectedCategory, options: categories.map((c) => ({ value: c, label: TRANSLATIONS[c] })), disabled: van.derive(() => !getValue(displayCategory)), - testId: 'category-selector', }), ), ), @@ -427,10 +427,10 @@ const Toolbar = ( userCanEdit ? div( { class: 'flex-row fx-align-flex-end fx-gap-3' }, Input({ + testId: 'scorecard-name-input', label: 'Scorecard Name', height: 40, value: scoreName, - testId: 'scorecard-name-input', onChange: debounce((name) => scoreName.val = name, 300), }), () => { diff --git a/testgen/ui/components/frontend/js/pages/table_group_list.js b/testgen/ui/components/frontend/js/pages/table_group_list.js index e5c126ba..0e0b819c 100644 --- a/testgen/ui/components/frontend/js/pages/table_group_list.js +++ b/testgen/ui/components/frontend/js/pages/table_group_list.js @@ -208,7 +208,6 @@ const TableGroupList = (props) => { ? div( { class: 'flex-column fx-gap-4' }, ...tableGroups.map((tableGroup) => Card({ - testId: 'table-group-card', class: '', title: div( { class: 'flex-column fx-gap-2 tg-tablegroup--card-title', 'data-testid': 'tablegroup-card-title' }, @@ -448,7 +447,6 @@ const Toolbar = (permissions, connections, selectedConnection, tableGroupNameFil {class: 'flex-row fx-align-flex-end fx-gap-3'}, () => (getValue(connections) ?? [])?.length > 1 ? Select({ - testId: 'connection-select', label: 'Connection', allowNull: true, value: connection, @@ -460,7 +458,6 @@ const Toolbar = (permissions, connections, selectedConnection, tableGroupNameFil }) : '', Input({ - testId: 'table-groups-name-filter', icon: 'search', label: '', placeholder: 'Search table group names', diff --git a/testgen/ui/components/frontend/js/pages/test_results.js b/testgen/ui/components/frontend/js/pages/test_results.js index d2a578fd..1fcafc6e 100644 --- a/testgen/ui/components/frontend/js/pages/test_results.js +++ b/testgen/ui/components/frontend/js/pages/test_results.js @@ -776,7 +776,6 @@ const TestResults = (/** @type Properties */ props) => { label: 'Status', value: statusFilter.val, options: STATUS_FILTER_OPTIONS, - testId: 'status-filter', style: 'min-width: 160px', onChange: onStatusFilterChange, allowNull: true, @@ -785,7 +784,6 @@ const TestResults = (/** @type Properties */ props) => { label: 'Table', value: tableFilter.val, options: tableOptions.val, - testId: 'table-filter', style: 'min-width: 180px', filterable: true, onChange: onTableFilterChange, @@ -795,7 +793,6 @@ const TestResults = (/** @type Properties */ props) => { label: 'Column', value: columnFilter.val, options: columnOptions.val, - testId: 'column-filter', style: 'min-width: 180px', filterable: true, acceptNewOptions: true, @@ -806,7 +803,6 @@ const TestResults = (/** @type Properties */ props) => { label: 'Test Type', value: testTypeFilter.val, options: testTypeOptions.val, - testId: 'test-type-filter', style: 'min-width: 160px', filterable: true, onChange: onTestTypeFilterChange, @@ -816,7 +812,6 @@ const TestResults = (/** @type Properties */ props) => { label: 'Action', value: actionFilter.val, options: ACTION_FILTER_OPTIONS, - testId: 'action-filter', style: 'min-width: 140px', onChange: onActionFilterChange, allowNull: true, @@ -825,7 +820,6 @@ const TestResults = (/** @type Properties */ props) => { label: 'Flagged', value: flaggedFilter.val, options: FLAGGED_FILTER_OPTIONS, - testId: 'flagged-filter', style: 'min-width: 140px', onChange: onFlaggedFilterChange, allowNull: true, @@ -901,7 +895,7 @@ const TestResults = (/** @type Properties */ props) => { { class: 'flex-column fx-flex', style: 'min-width: 0' }, hasData ? Tabs( - { testId: 'test-result-detail' }, + {}, Tab( { label: 'History' }, si.history?.length diff --git a/testgen/ui/components/frontend/js/pages/test_runs.js b/testgen/ui/components/frontend/js/pages/test_runs.js index 4a5b06a3..d95911ec 100644 --- a/testgen/ui/components/frontend/js/pages/test_runs.js +++ b/testgen/ui/components/frontend/js/pages/test_runs.js @@ -355,7 +355,6 @@ const Toolbar = ( options: getValue(props.table_group_options) ?? [], allowNull: true, style: 'font-size: 14px;', - testId: 'table-group-filter', onChange: (value) => emit('FilterApplied', { payload: { table_group_id: value } }), }), () => Select({ @@ -364,7 +363,6 @@ const Toolbar = ( options: getValue(props.test_suite_options) ?? [], allowNull: true, style: 'font-size: 14px;', - testId: 'test-suite-filter', onChange: (value) => emit('FilterApplied', { payload: { test_suite_id: value } }), }), ), @@ -407,7 +405,6 @@ const Toolbar = ( tooltipPosition: 'left', style: 'background: var(--button-generic-background-color);', onclick: () => emit('RefreshData', {}), - testId: 'test-runs-refresh', }), ), ); @@ -433,7 +430,6 @@ const TestRunItem = ( Checkbox({ checked: selected, onChange: (checked) => selected.val = checked, - testId: 'select-test-run', }), ) : '', diff --git a/testgen/ui/components/frontend/js/pages/test_suites.js b/testgen/ui/components/frontend/js/pages/test_suites.js index f9899747..158aea55 100644 --- a/testgen/ui/components/frontend/js/pages/test_suites.js +++ b/testgen/ui/components/frontend/js/pages/test_suites.js @@ -142,7 +142,6 @@ const TestSuites = (/** @type Properties */ props) => { }, }), () => Input({ - testId: 'test-suite-name-filter', icon: 'search', label: '', placeholder: 'Search test suite names', diff --git a/testgen/ui/components/frontend/js/shared/profiling_results_dialog.js b/testgen/ui/components/frontend/js/shared/profiling_results_dialog.js index 6f8d9140..4f87ef43 100644 --- a/testgen/ui/components/frontend/js/shared/profiling_results_dialog.js +++ b/testgen/ui/components/frontend/js/shared/profiling_results_dialog.js @@ -10,7 +10,6 @@ import { ColumnProfilingResults } from '../data_profiling/column_profiling_resul * @param {object} props.profilingColumn - reactive state: set to column data to open, null to close * @param {function} props.onClose - called when dialog is closed * @param {string} [props.width='52rem'] - * @param {string} [props.testId] */ const ProfilingResultsDialog = (props) => { const emit = props.emit; @@ -31,7 +30,7 @@ const ProfilingResultsDialog = (props) => { const columnJson = van.derive(() => columnData.val ? JSON.stringify(columnData.val) : null); return Dialog( - { title: 'Column Profiling Results', open, onClose, width: props.width || '52rem', testId: props.testId }, + { title: 'Column Profiling Results', open, onClose, width: props.width || '52rem' }, () => columnJson.val ? ColumnProfilingResults({ emit, column: columnJson }) : '', ); }; diff --git a/testgen/ui/components/frontend/js/shared/source_data_dialog.js b/testgen/ui/components/frontend/js/shared/source_data_dialog.js index 8645c18c..72a3775b 100644 --- a/testgen/ui/components/frontend/js/shared/source_data_dialog.js +++ b/testgen/ui/components/frontend/js/shared/source_data_dialog.js @@ -15,7 +15,6 @@ const { div, h4, small } = van.tags; * @param {function} props.onClose - called when dialog is closed * @param {function} [props.renderHeader] - (data) => VanJS node for page-specific metadata header * @param {string} [props.width='70rem'] - * @param {string} [props.testId] */ const SourceDataDialog = (props) => { const emit = props.emit; @@ -35,7 +34,7 @@ const SourceDataDialog = (props) => { }; return Dialog( - { title: 'Source Data', open, onClose, width: props.width || '70rem', testId: props.testId }, + { title: 'Source Data', open, onClose, width: props.width || '70rem' }, () => { const d = data.val; if (!d) return ''; diff --git a/testgen/ui/static/js/components/alert.js b/testgen/ui/static/js/components/alert.js index c01f2fc8..76d4e7a5 100644 --- a/testgen/ui/static/js/components/alert.js +++ b/testgen/ui/static/js/components/alert.js @@ -7,7 +7,6 @@ * @property {string?} class * @property {'info'|'success'|'warn'|'error'} type * @property {Function?} onClose - * @property {string?} testId */ import van from '../van.min.js'; import { getValue, loadStylesheet, getRandomId } from '../utils.js'; @@ -32,7 +31,7 @@ const Alert = (/** @type Properties */ props, /** @type Array */ .. { ...props, id: elementId, - 'data-testid': getValue(props.testId) ?? '', + 'data-testid': 'alert', class: () => `tg-alert flex-row ${getValue(props.class) ?? ''} tg-alert-${getValue(props.type)}`, role: 'alert', }, diff --git a/testgen/ui/static/js/components/breadcrumbs.js b/testgen/ui/static/js/components/breadcrumbs.js index 5280dd22..94f7f71a 100644 --- a/testgen/ui/static/js/components/breadcrumbs.js +++ b/testgen/ui/static/js/components/breadcrumbs.js @@ -8,7 +8,6 @@ * @typedef Properties * @type {object} * @property {Array.} breadcrumbs - * @property {string?} testId */ import van from '../van.min.js'; import { getValue, loadStylesheet } from '../utils.js'; @@ -18,10 +17,8 @@ const { a, div, span } = van.tags; const Breadcrumbs = (/** @type Properties */ props) => { loadStylesheet('breadcrumbs', stylesheet); - const testId = getValue(props.testId) ?? ''; - return div( - { class: 'tg-breadcrumbs-wrapper', 'data-testid': testId }, + { class: 'tg-breadcrumbs-wrapper', 'data-testid': 'breadcrumbs' }, () => { const breadcrumbs = getValue(props.breadcrumbs) || []; @@ -30,7 +27,7 @@ const Breadcrumbs = (/** @type Properties */ props) => { breadcrumbs.reduce((items, b, idx) => { const isLastItem = idx === breadcrumbs.length - 1; items.push(a({ - 'data-testid': testId ? `${testId}-item-${idx}` : '', + 'data-testid': 'breadcrumb-item', class: `tg-breadcrumbs--${ isLastItem ? 'current' : 'active'}`, onclick: (event) => { event.preventDefault(); diff --git a/testgen/ui/static/js/components/button.js b/testgen/ui/static/js/components/button.js index e839fc88..31f3b870 100644 --- a/testgen/ui/static/js/components/button.js +++ b/testgen/ui/static/js/components/button.js @@ -15,7 +15,6 @@ * @property {(bool)} loading * @property {('normal' | 'small')?} size * @property {string?} style - * @property {string?} testId */ import { getValue, loadStylesheet } from '../utils.js'; import van from '../van.min.js'; @@ -49,7 +48,7 @@ const Button = (/** @type Properties */ props) => { style: () => `width: ${isIconOnly ? '' : (width ?? '100%')}; ${getValue(props.style)}`, onclick: onClickHandler, disabled: isDisabled, - 'data-testid': getValue(props.testId) ?? '', + 'data-testid': 'button', }, span({class: 'tg-button-focus-state-indicator'}, ''), props.icon ? i({ diff --git a/testgen/ui/static/js/components/caption.js b/testgen/ui/static/js/components/caption.js index 8f7f21f4..e1820356 100644 --- a/testgen/ui/static/js/components/caption.js +++ b/testgen/ui/static/js/components/caption.js @@ -13,7 +13,7 @@ const Caption = (/** @type Properties */ props) => { loadStylesheet('caption', stylesheet); return span( - { class: 'tg-caption', style: props.style }, + { class: 'tg-caption', style: props.style, 'data-testid': 'caption' }, props.content ); } diff --git a/testgen/ui/static/js/components/card.js b/testgen/ui/static/js/components/card.js index 9102947e..c5fac911 100644 --- a/testgen/ui/static/js/components/card.js +++ b/testgen/ui/static/js/components/card.js @@ -7,7 +7,6 @@ * @property {boolean?} border * @property {string?} id * @property {string?} class - * @property {string?} testId */ import { loadStylesheet, getValue } from '../utils.js'; import van from '../van.min.js'; @@ -19,7 +18,7 @@ const Card = (/** @type Properties */ props) => { return div( { id: props.id ?? '', - 'data-testid': props.testId ?? '', + 'data-testid': 'card', class: () => { const classes = ['tg-card']; if (getValue(props.border)) { diff --git a/testgen/ui/static/js/components/checkbox.js b/testgen/ui/static/js/components/checkbox.js index da7ed63f..35c69150 100644 --- a/testgen/ui/static/js/components/checkbox.js +++ b/testgen/ui/static/js/components/checkbox.js @@ -8,7 +8,6 @@ * @property {boolean?} indeterminate * @property {function(boolean, Event)?} onChange * @property {number?} width - * @property {string?} testId * @property {boolean?} disabled */ import van from '../van.min.js'; @@ -31,7 +30,7 @@ const Checkbox = (/** @type Properties */ props) => { return label( { class: 'flex-row fx-gap-2 clickable', - 'data-testid': props.testId ?? props.name ?? '', + 'data-testid': 'checkbox', style: () => `width: ${props.width ? getValue(props.width) + 'px' : 'auto'}`, }, input({ diff --git a/testgen/ui/static/js/components/code.js b/testgen/ui/static/js/components/code.js index 414bd968..bada687d 100644 --- a/testgen/ui/static/js/components/code.js +++ b/testgen/ui/static/js/components/code.js @@ -2,7 +2,6 @@ * @typedef Options * @type {object} * @property {string?} id - * @property {string?} testId * @property {string?} class * @property {string?} language - Language for syntax highlighting (e.g. 'sql', 'html'). Omit for no highlighting. */ @@ -32,7 +31,7 @@ const Code = (options, ...children) => { ); const el = div( - { id: domId, class: `tg-code ${options.class ?? ''}`, 'data-testid': options.testId ?? '' }, + { id: domId, class: `tg-code ${options.class ?? ''}`, 'data-testid': 'code' }, pre({}, codeEl), Icon( { diff --git a/testgen/ui/static/js/components/crontab_input.js b/testgen/ui/static/js/components/crontab_input.js index cd60de89..ee607310 100644 --- a/testgen/ui/static/js/components/crontab_input.js +++ b/testgen/ui/static/js/components/crontab_input.js @@ -16,7 +16,7 @@ * @type {object} * @property {(string|null)} id * @property {(string|null)} name - * @property {string?} testId + * @property {string?} class * @property {CronSample?} sample * @property {InitialValue?} value @@ -68,7 +68,7 @@ const CrontabInput = (/** @type Options */ props) => { id: domId, class: () => `tg-crontab-input ${getValue(props.class) ?? ''}`, style: 'position: relative', - 'data-testid': getValue(props.testId) ?? null, + 'data-testid': 'crontab-input', }, div( {onclick: () => { diff --git a/testgen/ui/static/js/components/dialog.js b/testgen/ui/static/js/components/dialog.js index be8f09ef..dbff9e80 100644 --- a/testgen/ui/static/js/components/dialog.js +++ b/testgen/ui/static/js/components/dialog.js @@ -5,7 +5,6 @@ * @property {import('../van.min.js').State} open - Reactive open state * @property {Function} onClose - Called when the dialog is closed (backdrop click or X button) * @property {string} [width] - CSS width value, default '30rem' - * @property {string?} testId */ import van from '../van.min.js'; import { getValue, loadStylesheet } from '../utils.js'; @@ -29,22 +28,19 @@ const { button, div, i, span } = van.tags; * @param {DialogProps} props * @param {...(Element | string)} children - Content rendered in the dialog body */ -const Dialog = ({ title, open, onClose, width = '30rem', testId }, ...children) => { +const Dialog = ({ title, open, onClose, width = '30rem' }, ...children) => { loadStylesheet('dialog', stylesheet); - const testIdValue = getValue(testId) ?? ''; - const overlay = div( { class: 'tg-dialog-overlay', - 'data-testid': testIdValue ? `${testIdValue}-backdrop` : '', style: () => open.val ? '' : 'display: none', onclick: () => onClose(), }, div( { class: 'tg-dialog', - 'data-testid': testIdValue, + 'data-testid': 'dialog', role: 'dialog', 'aria-modal': 'true', tabindex: '-1', @@ -53,13 +49,13 @@ const Dialog = ({ title, open, onClose, width = '30rem', testId }, ...children) }, div( { class: 'tg-dialog-header' }, - span({ 'data-testid': testIdValue ? `${testIdValue}-title` : '', class: 'tg-dialog-title' }, title), + span({ 'data-testid': 'dialog-title', class: 'tg-dialog-title' }, title), ), div({ class: 'tg-dialog-content' }, ...children), button( { class: 'tg-dialog-close', - 'data-testid': testIdValue ? `${testIdValue}-close` : '', + 'data-testid': 'dialog-close', 'aria-label': 'Close', onclick: () => onClose(), }, diff --git a/testgen/ui/static/js/components/dropdown_button.js b/testgen/ui/static/js/components/dropdown_button.js index e97fdce6..7462141b 100644 --- a/testgen/ui/static/js/components/dropdown_button.js +++ b/testgen/ui/static/js/components/dropdown_button.js @@ -45,7 +45,7 @@ const DropdownButton = (props) => { () => { const items = typeof props.items === 'function' ? props.items() : props.items; return div( - { class: 'tg-dropdown-button--menu' }, + { class: 'tg-dropdown-button--menu', 'data-testid': 'dropdown-menu' }, ...items.map(item => div({ class: 'tg-dropdown-button--item', diff --git a/testgen/ui/static/js/components/empty_state.js b/testgen/ui/static/js/components/empty_state.js index d5240c7a..f8497b67 100644 --- a/testgen/ui/static/js/components/empty_state.js +++ b/testgen/ui/static/js/components/empty_state.js @@ -17,7 +17,7 @@ * @property {Link?} link * @property {any?} button * @property {string?} class -* @property {string?} testId + */ import van from '../van.min.js'; import { Card } from '../components/card.js'; @@ -70,7 +70,6 @@ const EmptyState = (/** @type Properties */ props) => { loadStylesheet('empty-state', stylesheet); return Card({ - testId: getValue(props.testId), class: `tg-empty-state flex-column fx-align-flex-center ${getValue(props.class ?? '')}`, content: [ span({ class: 'tg-empty-state--title mb-5' }, props.label), diff --git a/testgen/ui/static/js/components/expansion_panel.js b/testgen/ui/static/js/components/expansion_panel.js index 2cd5dd21..f584b127 100644 --- a/testgen/ui/static/js/components/expansion_panel.js +++ b/testgen/ui/static/js/components/expansion_panel.js @@ -2,7 +2,6 @@ * @typedef Options * @type {object} * @property {string} title - * @property {string?} testId * @property {bool} expanded */ @@ -46,7 +45,7 @@ const ExpansionPanel = (options, ...children) => { }); return div( - { class: 'tg-expansion-panel', 'data-testid': options.testId ?? '' }, + { class: 'tg-expansion-panel', 'data-testid': 'expansion-panel' }, titleDiv, contentDiv, ); diff --git a/testgen/ui/static/js/components/file_input.js b/testgen/ui/static/js/components/file_input.js index 5845c319..3abb66d4 100644 --- a/testgen/ui/static/js/components/file_input.js +++ b/testgen/ui/static/js/components/file_input.js @@ -113,7 +113,7 @@ const FileInput = (options) => { }; return div( - { class: cssClass }, + { class: cssClass, 'data-testid': 'file-input' }, div( { class: 'tg-file-uploader--label text-caption flex-row fx-gap-1' }, options.label, diff --git a/testgen/ui/static/js/components/frequency_bars.js b/testgen/ui/static/js/components/frequency_bars.js index d26073ce..1a65ea38 100644 --- a/testgen/ui/static/js/components/frequency_bars.js +++ b/testgen/ui/static/js/components/frequency_bars.js @@ -36,6 +36,7 @@ const FrequencyBars = (/** @type Properties */ props) => { }); return () => div( + { 'data-testid': 'frequency-bars' }, div( { class: 'mb-2 text-secondary' }, props.title, diff --git a/testgen/ui/static/js/components/input.js b/testgen/ui/static/js/components/input.js index 1efb0924..99eab6e0 100644 --- a/testgen/ui/static/js/components/input.js +++ b/testgen/ui/static/js/components/input.js @@ -30,7 +30,6 @@ * @property {string?} style * @property {string?} type * @property {string?} class - * @property {string?} testId * @property {any?} prefix * @property {number} step * @property {Array?} validators @@ -103,7 +102,7 @@ const Input = (/** @type Properties */ props) => { id: domId, class: () => `flex-column fx-gap-1 tg-input--label ${getValue(props.class) ?? ''}`, style: () => `width: ${props.width ? getValue(props.width) + 'px' : 'auto'}; ${getValue(props.style)}`, - 'data-testid': props.testId ?? props.name ?? '', + 'data-testid': 'input', }, div( { class: 'flex-row fx-gap-1 text-caption' }, diff --git a/testgen/ui/static/js/components/line_chart.js b/testgen/ui/static/js/components/line_chart.js index fd16bd06..722588ce 100644 --- a/testgen/ui/static/js/components/line_chart.js +++ b/testgen/ui/static/js/components/line_chart.js @@ -223,7 +223,6 @@ const LineChart = ( tooltipExtraStyle.val = ''; showTooltip.val = false; }, - testId: lineId, }, line, ) diff --git a/testgen/ui/static/js/components/link.js b/testgen/ui/static/js/components/link.js index 630d6d76..c78e821c 100644 --- a/testgen/ui/static/js/components/link.js +++ b/testgen/ui/static/js/components/link.js @@ -18,7 +18,6 @@ * @property {string?} tooltipPosition * @property {boolean?} disabled * @property {((event: any) => void)?} onClick - * @property {string?} testId */ import { getValue, loadStylesheet } from '../utils.js'; import van from '../van.min.js'; @@ -38,7 +37,7 @@ const Link = (/** @type Properties */ props) => { return a( { - 'data-testid': getValue(props.testId) ?? '', + 'data-testid': 'link', class: `tg-link ${getValue(props.underline) ? 'tg-link--underline' : ''} ${getValue(props.disabled) ? 'disabled' : ''} diff --git a/testgen/ui/static/js/components/notification_settings.js b/testgen/ui/static/js/components/notification_settings.js index b3cf9bce..d644f467 100644 --- a/testgen/ui/static/js/components/notification_settings.js +++ b/testgen/ui/static/js/components/notification_settings.js @@ -231,7 +231,6 @@ const NotificationSettings = (/** @type Properties */ props) => { title: () => newNotificationItemForm.isEdit.val ? span({ class: 'notifications--editing' }, 'Edit Notification') : 'Add Notification', - testId: 'notification-item-editor', expanded: panelExpanded, }, div( diff --git a/testgen/ui/static/js/components/paginator.js b/testgen/ui/static/js/components/paginator.js index cd4a4be8..663c5add 100644 --- a/testgen/ui/static/js/components/paginator.js +++ b/testgen/ui/static/js/components/paginator.js @@ -5,7 +5,6 @@ * @property {number} pageSize * @property {number?} pageIndex * @property {function(number)?} onChange - * @property {string?} testId */ import van from '../van.min.js'; @@ -18,7 +17,6 @@ const Paginator = (/** @type Properties */ props) => { loadStylesheet('paginator', stylesheet); const { count, pageSize } = props; - const testId = getValue(props.testId) ?? ''; const pageIndexState = van.derive(() => getValue(props.pageIndex) ?? 0); van.derive(() => { @@ -27,9 +25,9 @@ const Paginator = (/** @type Properties */ props) => { }); return div( - { class: 'tg-paginator', 'data-testid': testId }, + { class: 'tg-paginator', 'data-testid': 'paginator' }, span( - { class: 'tg-paginator--label', 'data-testid': testId ? `${testId}-info` : '' }, + { class: 'tg-paginator--label', 'data-testid': 'paginator-info' }, () => { const pageIndex = pageIndexState.val; const countValue = getValue(count); @@ -40,7 +38,7 @@ const Paginator = (/** @type Properties */ props) => { button( { class: 'tg-paginator--button', - 'data-testid': testId ? `${testId}-first` : '', + 'aria-label': 'First page', onclick: () => pageIndexState.val = 0, disabled: () => pageIndexState.val === 0, }, @@ -49,7 +47,7 @@ const Paginator = (/** @type Properties */ props) => { button( { class: 'tg-paginator--button', - 'data-testid': testId ? `${testId}-prev` : '', + 'aria-label': 'Previous page', onclick: () => pageIndexState.val--, disabled: () => pageIndexState.val === 0, }, @@ -58,7 +56,7 @@ const Paginator = (/** @type Properties */ props) => { button( { class: 'tg-paginator--button', - 'data-testid': testId ? `${testId}-next` : '', + 'aria-label': 'Next page', onclick: () => pageIndexState.val++, disabled: () => pageIndexState.val === Math.ceil(getValue(count) / getValue(pageSize)) - 1, }, @@ -67,7 +65,7 @@ const Paginator = (/** @type Properties */ props) => { button( { class: 'tg-paginator--button', - 'data-testid': testId ? `${testId}-last` : '', + 'aria-label': 'Last page', onclick: () => pageIndexState.val = Math.ceil(getValue(count) / getValue(pageSize)) - 1, disabled: () => pageIndexState.val === Math.ceil(getValue(count) / getValue(pageSize)) - 1, }, diff --git a/testgen/ui/static/js/components/radio_group.js b/testgen/ui/static/js/components/radio_group.js index 97aef2df..1d4b29c2 100644 --- a/testgen/ui/static/js/components/radio_group.js +++ b/testgen/ui/static/js/components/radio_group.js @@ -31,7 +31,7 @@ const RadioGroup = (/** @type Properties */ props) => { const disabled = getValue(props.disabled) ?? false; return div( - { class: () => `tg-radio-group--wrapper ${layout}${disabled ? ' disabled' : ''}`, style: () => `width: ${props.width ? getValue(props.width) + 'px' : 'auto'}` }, + { class: () => `tg-radio-group--wrapper ${layout}${disabled ? ' disabled' : ''}`, style: () => `width: ${props.width ? getValue(props.width) + 'px' : 'auto'}`, 'data-testid': 'radio-group' }, div( { class: 'text-caption tg-radio-group--label flex-row fx-gap-1' }, props.label, diff --git a/testgen/ui/static/js/components/schedule_list.js b/testgen/ui/static/js/components/schedule_list.js index ccd75e63..a81d590b 100644 --- a/testgen/ui/static/js/components/schedule_list.js +++ b/testgen/ui/static/js/components/schedule_list.js @@ -78,7 +78,7 @@ const ScheduleList = (/** @type Properties */ props) => { const content = div( { id: domId, class: 'flex-column fx-gap-2', style: 'height: 100%; overflow-y: auto;' }, ExpansionPanel( - {title: span({ class: 'text-green' }, 'Add Schedule'), testId: 'scheduler-cron-editor'}, + {title: span({ class: 'text-green' }, 'Add Schedule')}, div( { class: 'flex-row fx-gap-2' }, () => Select({ diff --git a/testgen/ui/static/js/components/score_breakdown.js b/testgen/ui/static/js/components/score_breakdown.js index 717c4d36..ba123cc7 100644 --- a/testgen/ui/static/js/components/score_breakdown.js +++ b/testgen/ui/static/js/components/score_breakdown.js @@ -26,6 +26,7 @@ const ScoreBreakdown = (score, breakdown, category, scoreType, onViewDetails, em () => { const selectedCategory = getValue(category); return Select({ + testId: 'groupby-selector', label: '', value: selectedCategory, options: Object.entries(CATEGORIES) @@ -33,7 +34,6 @@ const ScoreBreakdown = (score, breakdown, category, scoreType, onViewDetails, em .map(([value, label]) => ({ value, label })), height: 32, onChange: (value) => emit('CategoryChanged', { payload: value }), - testId: 'groupby-selector', }); }, span('for'), @@ -45,12 +45,12 @@ const ScoreBreakdown = (score, breakdown, category, scoreType, onViewDetails, em scoreTypeOptions.push('score'); } return Select({ + testId: 'score-type-selector', label: '', value: selectedScoreType, options: scoreTypeOptions.map((s) => ({ label: SCORE_TYPE_LABEL[s], value: s })), height: 32, onChange: (value) => emit('ScoreTypeChanged', { payload: value }), - testId: 'score-type-selector', }); }, ), diff --git a/testgen/ui/static/js/components/select.js b/testgen/ui/static/js/components/select.js index 70bc57bd..5a7776cb 100644 --- a/testgen/ui/static/js/components/select.js +++ b/testgen/ui/static/js/components/select.js @@ -21,7 +21,6 @@ * @property {number?} width * @property {number?} height * @property {string?} style - * @property {string?} testId * @property {number?} portalClass * @property {('top' | 'bottom')?} portalPosition * @property {boolean?} filterable @@ -189,7 +188,7 @@ const Select = (/** @type {Properties} */ props) => { id: domId, class: () => `flex-column fx-gap-1 text-caption tg-select--label ${getValue(props.disabled) ? 'disabled' : ''}`, style: () => `width: ${props.width ? getValue(props.width) + 'px' : 'auto'}; ${getValue(props.style)}`, - 'data-testid': getValue(props.testId) ?? '', + 'data-testid': 'select', onclick: (/** @type Event */ event) => { event.stopPropagation(); event.stopImmediatePropagation(); @@ -307,7 +306,7 @@ const MultiSelect = (props) => { id: domId, class: () => `flex-column fx-gap-1 text-caption tg-select--label ${getValue(props.disabled) ? 'disabled' : ''}`, style: () => `width: ${props.width ? getValue(props.width) + 'px' : 'auto'}; ${getValue(props.style)}`, - 'data-testid': getValue(props.testId) ?? '', + 'data-testid': 'select', onclick: (/** @type Event */ event) => { event.stopPropagation(); event.stopImmediatePropagation(); diff --git a/testgen/ui/static/js/components/slider.js b/testgen/ui/static/js/components/slider.js index 2582fc8b..e59fc6f4 100644 --- a/testgen/ui/static/js/components/slider.js +++ b/testgen/ui/static/js/components/slider.js @@ -25,7 +25,7 @@ const Slider = (/** @type Properties */ props) => { }; return label( - { class: 'flex-col fx-gap-1 clickable tg-slider--label text-caption' }, + { class: 'flex-col fx-gap-1 clickable tg-slider--label text-caption', 'data-testid': 'slider' }, props.label, input({ type: "range", diff --git a/testgen/ui/static/js/components/spark_line.js b/testgen/ui/static/js/components/spark_line.js index 89985808..ee0f9a4f 100644 --- a/testgen/ui/static/js/components/spark_line.js +++ b/testgen/ui/static/js/components/spark_line.js @@ -8,7 +8,7 @@ * @property {boolean?} interactive * @property {Function?} onPointMouseEnter * @property {Function?} onPointMouseLeave - * @property {string?} testId + * * @typedef Point * @type {object} @@ -35,7 +35,7 @@ const SparkLine = ( ) => { const display = van.derive(() => getValue(options.hidden) === true ? 'none' : ''); return g( - { fill: 'none', opacity: options.opacity ?? 1, style: 'overflow: visible;', 'data-testid': options.testId, display }, + { fill: 'none', opacity: options.opacity ?? 1, style: 'overflow: visible;', 'data-testid': 'sparkline', display }, polyline({ points: line.map(point => `${point.x} ${point.y}`).join(', '), style: `stroke: ${options.color}; stroke-width: ${options.stroke ?? 1};`, diff --git a/testgen/ui/static/js/components/summary_bar.js b/testgen/ui/static/js/components/summary_bar.js index c16dcc61..c133453f 100644 --- a/testgen/ui/static/js/components/summary_bar.js +++ b/testgen/ui/static/js/components/summary_bar.js @@ -25,6 +25,7 @@ const SummaryBar = (/** @type Properties */ props) => { const total = van.derive(() => getValue(props.items).reduce((sum, item) => sum + item.value, 0)); return div( + { 'data-testid': 'summary-bar' }, () => props.label ? div( { class: 'tg-summary-bar--label' }, props.label, diff --git a/testgen/ui/static/js/components/summary_counts.js b/testgen/ui/static/js/components/summary_counts.js index 46f5533a..918e074c 100644 --- a/testgen/ui/static/js/components/summary_counts.js +++ b/testgen/ui/static/js/components/summary_counts.js @@ -19,7 +19,7 @@ const SummaryCounts = (/** @type Properties */ props) => { loadStylesheet('summaryCounts', stylesheet); return div( - { class: 'flex-row fx-gap-5 fx-flex-wrap' }, + { class: 'flex-row fx-gap-5 fx-flex-wrap', 'data-testid': 'summary-counts' }, getValue(props.items).map(item => div( { class: 'flex-row fx-align-stretch fx-gap-2' }, div({ class: 'tg-summary-counts--bar', style: `background-color: ${colorMap[item.color] || item.color};` }), diff --git a/testgen/ui/static/js/components/table.js b/testgen/ui/static/js/components/table.js index ac53acc1..c4016f54 100644 --- a/testgen/ui/static/js/components/table.js +++ b/testgen/ui/static/js/components/table.js @@ -421,7 +421,6 @@ const Paginatior = ( span({class: 'mr-2'}, 'Rows per page:'), Select({ triggerStyle: 'inline', - testId: 'items-per-page', value: itemsPerPage, options: sizeOptions, portalPosition: 'top', diff --git a/testgen/ui/static/js/components/table_group_form.js b/testgen/ui/static/js/components/table_group_form.js index 8fc96dc2..46dd4353 100644 --- a/testgen/ui/static/js/components/table_group_form.js +++ b/testgen/ui/static/js/components/table_group_form.js @@ -400,7 +400,7 @@ const SamplingForm = ( profileSampleMinCount, ) => { return ExpansionPanel( - { title: 'Sampling Parameters', testId: 'sampling-panel' }, + { title: 'Sampling Parameters' }, div( { class: 'flex-column fx-gap-3' }, Checkbox({ @@ -454,7 +454,7 @@ const TaggingForm = ( dataProduct, ) => { return ExpansionPanel( - { title: 'Table Group Tags', testId: 'tags-panel' }, + { title: 'Table Group Tags' }, Input({ name: 'description', class: 'fx-flex mb-3', diff --git a/testgen/ui/static/js/components/tabs.js b/testgen/ui/static/js/components/tabs.js index 23d315d6..d6cbc51c 100644 --- a/testgen/ui/static/js/components/tabs.js +++ b/testgen/ui/static/js/components/tabs.js @@ -19,7 +19,6 @@ const Tab = ({ label }, ...children) => ({ /** * @typedef {Object} TabsProps - * @property {string?} testId * @property {string?} class * * @param {TabsProps} props @@ -28,8 +27,7 @@ const Tab = ({ label }, ...children) => ({ const Tabs = (props, ...tabs) => { loadStylesheet('tabs', stylesheet); - const { testId: testIdProp, ...restProps } = props; - const testId = getValue(testIdProp) ?? ''; + const { ...restProps } = props; const activeTab = van.state(0); @@ -52,7 +50,7 @@ const Tabs = (props, ...tabs) => { ...tabs.map((tab, i) => button({ class: () => `tg-tabs--tab--label ${i === activeTab.val ? 'active' : ''}`, - 'data-testid': testId ? `${testId}-tab-${i}` : '', + 'data-testid': 'tab', onclick: () => (activeTab.val = i), }, tab.label @@ -60,9 +58,9 @@ const Tabs = (props, ...tabs) => { highlightEl, ); - const tabsContainerEl = div({ ...restProps, 'data-testid': testId, class: () => `${getValue(restProps.class) ?? ''} tg-tabs--container` }, + const tabsContainerEl = div({ ...restProps, 'data-testid': 'tabs', class: () => `${getValue(restProps.class) ?? ''} tg-tabs--container` }, labelsContainerEl, - div({ class: "tg-tabs--content", 'data-testid': testId ? `${testId}-panel` : '' }, () => div({class: "tg-tabs--content-inner"}, tabs[activeTab.val].children)), + div({ class: "tg-tabs--content", 'data-testid': 'tab-panel' }, () => div({class: "tg-tabs--content-inner"}, tabs[activeTab.val].children)), ); van.derive(() => { diff --git a/testgen/ui/static/js/components/textarea.js b/testgen/ui/static/js/components/textarea.js index bdfc411a..9e08f9e6 100644 --- a/testgen/ui/static/js/components/textarea.js +++ b/testgen/ui/static/js/components/textarea.js @@ -22,7 +22,6 @@ * @property {string?} class * @property {number?} width * @property {number?} height - * @property {string?} testId * @property {Array?} validators */ import van from '../van.min.js'; @@ -68,7 +67,7 @@ const Textarea = (/** @type Properties */ props) => { id: domId, class: () => `flex-column fx-gap-1 ${getValue(props.class) ?? ''}`, style: () => `width: ${props.width ? getValue(props.width) + 'px' : 'auto'}; ${getValue(props.style)}`, - 'data-testid': props.testId ?? props.name ?? '', + 'data-testid': 'textarea', }, div( { class: 'flex-row fx-gap-1 text-caption' }, diff --git a/testgen/ui/static/js/components/toggle.js b/testgen/ui/static/js/components/toggle.js index eb723c38..4911d953 100644 --- a/testgen/ui/static/js/components/toggle.js +++ b/testgen/ui/static/js/components/toggle.js @@ -19,7 +19,7 @@ const Toggle = (/** @type Properties */ props) => { const disabled = props.disabled?.val ?? props.disabled ?? false; return label( - { class: `flex-row fx-gap-2 ${disabled ? '' : 'clickable'}`, style: props.style ?? '', 'data-testid': props.name ?? '' }, + { class: `flex-row fx-gap-2 ${disabled ? '' : 'clickable'}`, style: props.style ?? '', 'data-testid': 'toggle' }, input({ type: 'checkbox', role: 'switch', diff --git a/testgen/ui/static/js/components/tree.js b/testgen/ui/static/js/components/tree.js index b9902269..2e8dcd84 100644 --- a/testgen/ui/static/js/components/tree.js +++ b/testgen/ui/static/js/components/tree.js @@ -90,6 +90,7 @@ const Tree = (/** @type Properties */ props, /** @type any? */ searchOptionsCont { id: props.id, class: () => `flex-column ${getValue(props.classes)}`, + 'data-testid': 'tree', }, Toolbar(treeNodes, multiSelect, props, searchOptionsContent, filtersContent, emit), div( From 571bf4142556473287f19d5adc4bc92dd2e37cbb Mon Sep 17 00:00:00 2001 From: Luis Date: Fri, 17 Apr 2026 09:00:44 -0400 Subject: [PATCH 02/58] refactor(ui): add data-value to help e2e tests --- .../components/frontend/js/pages/hygiene_issues.js | 12 ++++++------ .../frontend/js/pages/quality_dashboard.js | 2 +- .../frontend/js/pages/test_definitions.js | 9 +++++---- .../ui/components/frontend/js/pages/test_results.js | 13 +++++++------ testgen/ui/static/js/components/attribute.js | 10 +++++++--- testgen/ui/static/js/components/card.js | 2 +- testgen/ui/static/js/components/dialog.js | 2 +- testgen/ui/static/js/components/dropdown_button.js | 1 + testgen/ui/static/js/components/input.js | 1 + testgen/ui/static/js/components/score_legend.js | 2 +- testgen/ui/static/js/components/summary_counts.js | 6 +++--- testgen/ui/static/js/components/table.js | 5 +++-- .../ui/static/js/components/table_group_wizard.js | 2 +- testgen/ui/static/js/components/textarea.js | 1 + 14 files changed, 39 insertions(+), 29 deletions(-) diff --git a/testgen/ui/components/frontend/js/pages/hygiene_issues.js b/testgen/ui/components/frontend/js/pages/hygiene_issues.js index 8f22132a..2682f265 100644 --- a/testgen/ui/components/frontend/js/pages/hygiene_issues.js +++ b/testgen/ui/components/frontend/js/pages/hygiene_issues.js @@ -458,7 +458,7 @@ const HygieneIssues = (/** @type Properties */ props) => { // Table header bar (actions above the table) const tableHeader = div( - { class: 'flex-row fx-align-center fx-gap-2 p-2' }, + { 'data-testid': 'table-header', class: 'flex-row fx-align-center fx-gap-2 p-2' }, Toggle({ label: () => { return div( @@ -481,7 +481,7 @@ const HygieneIssues = (/** @type Properties */ props) => { if (!permissions.val.can_disposition) return ''; const disabled = allSelectedArePassed.val; return div( - { class: 'flex-row fx-gap-1' }, + { 'data-testid': 'disposition-actions', class: 'flex-row fx-gap-1' }, Button({ type: 'icon', icon: 'check_circle', tooltip: 'Confirm selected as relevant', disabled, onclick: () => onDisposition('Confirmed') }), Button({ type: 'icon', icon: 'cancel', tooltip: 'Dismiss selected as not relevant', disabled, onclick: () => onDisposition('Dismissed') }), Button({ type: 'icon', icon: 'notifications_off', tooltip: 'Mute selected for future runs', disabled, onclick: () => onDisposition('Inactive') }), @@ -576,14 +576,14 @@ const HygieneIssues = (/** @type Properties */ props) => { { class: 'flex-row fx-gap-5 fx-align-flex-end mb-3 fx-flex-wrap' }, () => othersSummary.val.length ? div( - { class: 'flex-column fx-gap-1' }, + { 'data-testid': 'hygiene-issues-summary', class: 'flex-column fx-gap-1' }, div({ class: 'text-caption' }, 'Hygiene Issues'), SummaryCounts({ items: othersSummary.val }), ) : '', () => piiSummary.val.length ? div( - { class: 'flex-column fx-gap-1' }, + { 'data-testid': 'hygiene-pii-summary', class: 'flex-column fx-gap-1' }, div({ class: 'text-caption' }, 'Potential PII (Risk)'), SummaryCounts({ items: piiSummary.val }), ) @@ -594,7 +594,7 @@ const HygieneIssues = (/** @type Properties */ props) => { div( { class: 'flex-column' }, div({ class: 'text-caption'}, 'Score'), - div({ style: 'font-size: 28px' }, score), + div({ 'data-testid': 'hygiene-score', style: 'font-size: 28px' }, score), ), Button({ type: 'icon', @@ -668,7 +668,7 @@ const HygieneIssues = (/** @type Properties */ props) => { if (!sel) return ''; return div( - { class: 'tg-hi--detail flex-column fx-gap-4' }, + { 'data-testid': 'hygiene-issue-detail', class: 'tg-hi--detail flex-column fx-gap-4' }, div( { class: 'flex-row fx-gap-2 fx-justify-content-flex-end' }, sel.table_name !== '(multi-table)' diff --git a/testgen/ui/components/frontend/js/pages/quality_dashboard.js b/testgen/ui/components/frontend/js/pages/quality_dashboard.js index a592678b..1a9ed470 100644 --- a/testgen/ui/components/frontend/js/pages/quality_dashboard.js +++ b/testgen/ui/components/frontend/js/pages/quality_dashboard.js @@ -129,10 +129,10 @@ const Toolbar = ( label: 'Score Explorer', color: 'primary', style: 'background: var(--button-generic-background-color); width: unset;', + testId: 'scorecards-goto-explorer', onclick: () => emit('LinkClicked', { href: 'quality-dashboard:explorer', params: { project_code: projectSummary.project_code }, - testId: 'scorecards-goto-explorer', }), }), Button({ diff --git a/testgen/ui/components/frontend/js/pages/test_definitions.js b/testgen/ui/components/frontend/js/pages/test_definitions.js index aeff2ab9..d715178e 100644 --- a/testgen/ui/components/frontend/js/pages/test_definitions.js +++ b/testgen/ui/components/frontend/js/pages/test_definitions.js @@ -75,6 +75,7 @@ const BLANK_PARAM_FIELDS = { const ClearFlagButton = ({ disabled, onclick }) => { return withTooltip(btn( { + 'data-testid': 'button', class: 'tg-button tg-icon-button tg-basic-button', disabled, onclick, @@ -394,7 +395,7 @@ const TestDefinitions = (/** @type object */ props) => { // Table header bar: multi-select toggle + edit buttons | dashed separator | disposition buttons + export const tableHeader = div( - { class: 'flex-row fx-align-center fx-gap-2 p-2 fx-flex-wrap' }, + { 'data-testid': 'table-header', class: 'flex-row fx-align-center fx-gap-2 p-2 fx-flex-wrap' }, () => canDisposition.val ? Toggle({ label: () => { @@ -429,7 +430,7 @@ const TestDefinitions = (/** @type object */ props) => { test_type: r.test_type, lock_refresh: r.lock_refresh, })); return div( - { class: 'flex-row fx-gap-1' }, + { 'data-testid': 'edit-actions', class: 'flex-row fx-gap-1' }, Button({ type: 'icon', icon: 'file_copy', tooltip: 'Copy/Move', disabled: !hasSelection, onclick: () => emit('CopyMoveDialogOpened', { payload: isAll ? 'all' : minimalSelected() }) }), Button({ type: 'icon', icon: 'delete', tooltip: 'Delete', disabled: !hasSelection, @@ -462,7 +463,7 @@ const TestDefinitions = (/** @type object */ props) => { } }; return div( - { class: 'flex-row fx-gap-1' }, + { 'data-testid': 'disposition-actions', class: 'flex-row fx-gap-1' }, Button({ type: 'icon', icon: 'check_circle', tooltip: 'Activate selected', disabled: noSelection || allActive, onclick: () => emitAttribute('test_active', true) }), Button({ type: 'icon', icon: 'notifications_off', tooltip: 'Deactivate selected', disabled: noSelection || allInactive, onclick: () => emitAttribute('test_active', false) }), div({ class: 'td-header-separator' }), @@ -758,7 +759,7 @@ const TestDefinitions = (/** @type object */ props) => { const row = singleSelected.val; if (!row) return ''; return div( - { class: 'tg-td--detail flex-column fx-gap-4' }, + { 'data-testid': 'test-definition-detail', class: 'tg-td--detail flex-column fx-gap-4' }, div( { class: 'flex-row fx-gap-2 fx-justify-content-flex-end' }, canEdit.val ? Button({ diff --git a/testgen/ui/components/frontend/js/pages/test_results.js b/testgen/ui/components/frontend/js/pages/test_results.js index 1fcafc6e..8fdc3bab 100644 --- a/testgen/ui/components/frontend/js/pages/test_results.js +++ b/testgen/ui/components/frontend/js/pages/test_results.js @@ -74,6 +74,7 @@ const STATUS_COLORS = { const ClearFlagButton = ({ disabled, onclick }) => { return withTooltip(btn( { + 'data-testid': 'button', class: 'tg-button tg-icon-button tg-basic-button', tooltip: 'Clear flag', disabled, @@ -540,7 +541,7 @@ const TestResults = (/** @type Properties */ props) => { // Table header bar const tableHeader = div( - { class: 'flex-row fx-align-center fx-gap-2 p-2' }, + { 'data-testid': 'table-header', class: 'flex-row fx-align-center fx-gap-2 p-2' }, Toggle({ label: () => { return div( @@ -569,7 +570,7 @@ const TestResults = (/** @type Properties */ props) => { ? !isAll && count === 0 : (() => { const row = selectedRow.val; return !row || row.result_status === 'Passed'; })(); return div( - { class: 'flex-row fx-gap-1' }, + { 'data-testid': 'disposition-actions', class: 'flex-row fx-gap-1' }, Button({ type: 'icon', icon: 'check_circle', tooltip: 'Confirm selected as relevant', disabled, onclick: () => onDisposition('Confirmed') }), Button({ type: 'icon', icon: 'cancel', tooltip: 'Dismiss selected as not relevant', disabled, onclick: () => onDisposition('Dismissed') }), Button({ type: 'icon', icon: 'notifications_off', tooltip: 'Mute selected tests for future runs', disabled, onclick: () => onDisposition('Inactive') }), @@ -599,7 +600,7 @@ const TestResults = (/** @type Properties */ props) => { }; return div( - { class: 'flex-row fx-gap-1' }, + { 'data-testid': 'flag-actions', class: 'flex-row fx-gap-1' }, span({ style: 'width: 0px; height: 24px; border-right: 1px dashed var(--border-color);'}, ''), Button({ type: 'icon', icon: 'flag', tooltip: 'Flag selected', disabled: noSelection, @@ -757,7 +758,7 @@ const TestResults = (/** @type Properties */ props) => { div( { class: 'tg-tr--score flex-column fx-align-center' }, small({ class: 'text-caption' }, 'Score'), - span({ class: 'tg-tr--score-value' }, () => getValue(props.score) ?? '--'), + span({ 'data-testid': 'test-run-score', class: 'tg-tr--score-value' }, () => getValue(props.score) ?? '--'), ), Button({ type: 'icon', @@ -840,7 +841,7 @@ const TestResults = (/** @type Properties */ props) => { const hasData = si && si.test_result_id === row.test_result_id; return div( - { class: 'tg-tr--detail flex-column fx-gap-4' }, + { 'data-testid': 'test-result-detail', class: 'tg-tr--detail flex-column fx-gap-4' }, // Action buttons row div( @@ -879,7 +880,7 @@ const TestResults = (/** @type Properties */ props) => { { class: 'flex-column fx-flex', style: 'min-width: 0' }, h3({ class: 'tg-tr--detail-title' }, row.test_name_short), row.test_description - ? p({ class: 'tg-tr--detail-desc' }, row.test_description) + ? p({ 'data-testid': 'test-result-description', class: 'tg-tr--detail-desc' }, row.test_description) : '', row.measure_uom_description ? small({ class: 'text-caption' }, row.measure_uom_description) diff --git a/testgen/ui/static/js/components/attribute.js b/testgen/ui/static/js/components/attribute.js index a7bb60eb..8e8ab4bc 100644 --- a/testgen/ui/static/js/components/attribute.js +++ b/testgen/ui/static/js/components/attribute.js @@ -19,9 +19,13 @@ const Attribute = (/** @type Properties */ props) => { loadStylesheet('attribute', stylesheet); return div( - { style: () => `width: ${props.width ? getValue(props.width) + 'px' : 'auto'}`, class: props.class }, + { + 'data-testid': 'attribute', + style: () => `width: ${props.width ? getValue(props.width) + 'px' : 'auto'}`, + class: props.class, + }, div( - { class: 'flex-row fx-gap-1 text-caption mb-1' }, + { 'data-testid': 'attribute-label', class: 'flex-row fx-gap-1 text-caption mb-1' }, props.label, () => getValue(props.help) ? withTooltip( @@ -31,7 +35,7 @@ const Attribute = (/** @type Properties */ props) => { : null, ), div( - { class: 'attribute-value' }, + { 'data-testid': 'attribute-value', class: 'attribute-value' }, () => { const value = getValue(props.value); if (value === PII_REDACTED) { diff --git a/testgen/ui/static/js/components/card.js b/testgen/ui/static/js/components/card.js index c5fac911..c5d88cba 100644 --- a/testgen/ui/static/js/components/card.js +++ b/testgen/ui/static/js/components/card.js @@ -18,7 +18,7 @@ const Card = (/** @type Properties */ props) => { return div( { id: props.id ?? '', - 'data-testid': 'card', + 'data-testid': props.testId ?? 'card', class: () => { const classes = ['tg-card']; if (getValue(props.border)) { diff --git a/testgen/ui/static/js/components/dialog.js b/testgen/ui/static/js/components/dialog.js index dbff9e80..465c3faf 100644 --- a/testgen/ui/static/js/components/dialog.js +++ b/testgen/ui/static/js/components/dialog.js @@ -51,7 +51,7 @@ const Dialog = ({ title, open, onClose, width = '30rem' }, ...children) => { { class: 'tg-dialog-header' }, span({ 'data-testid': 'dialog-title', class: 'tg-dialog-title' }, title), ), - div({ class: 'tg-dialog-content' }, ...children), + div({ 'data-testid': 'dialog-content', class: 'tg-dialog-content' }, ...children), button( { class: 'tg-dialog-close', diff --git a/testgen/ui/static/js/components/dropdown_button.js b/testgen/ui/static/js/components/dropdown_button.js index 7462141b..b3045466 100644 --- a/testgen/ui/static/js/components/dropdown_button.js +++ b/testgen/ui/static/js/components/dropdown_button.js @@ -49,6 +49,7 @@ const DropdownButton = (props) => { ...items.map(item => div({ class: 'tg-dropdown-button--item', + 'data-testid': 'dropdown-item', style: item.separator ? 'border-top: var(--button-stroked-border);' : '', onclick: () => { menuOpen.val = false; item.onclick(); }, }, item.label), diff --git a/testgen/ui/static/js/components/input.js b/testgen/ui/static/js/components/input.js index 99eab6e0..3939838b 100644 --- a/testgen/ui/static/js/components/input.js +++ b/testgen/ui/static/js/components/input.js @@ -133,6 +133,7 @@ const Input = (/** @type Properties */ props) => { : undefined, () => input({ value, + 'data-value': value, name: props.name ?? '', type: inputType, disabled: props.disabled, diff --git a/testgen/ui/static/js/components/score_legend.js b/testgen/ui/static/js/components/score_legend.js index e5b53281..13265d6d 100644 --- a/testgen/ui/static/js/components/score_legend.js +++ b/testgen/ui/static/js/components/score_legend.js @@ -6,7 +6,7 @@ const { div, span } = van.tags; const ScoreLegend = (/** @type string */ style) => { return div( - { class: 'flex-row fx-gap-3 text-secondary', style }, + { 'data-testid': 'score-legend', class: 'flex-row fx-gap-3 text-secondary', style }, span({ class: 'fx-flex' }), LegendItem('N/A', NaN), LegendItem('0-85', 0), diff --git a/testgen/ui/static/js/components/summary_counts.js b/testgen/ui/static/js/components/summary_counts.js index 918e074c..d307f58e 100644 --- a/testgen/ui/static/js/components/summary_counts.js +++ b/testgen/ui/static/js/components/summary_counts.js @@ -21,11 +21,11 @@ const SummaryCounts = (/** @type Properties */ props) => { return div( { class: 'flex-row fx-gap-5 fx-flex-wrap', 'data-testid': 'summary-counts' }, getValue(props.items).map(item => div( - { class: 'flex-row fx-align-stretch fx-gap-2' }, + { 'data-testid': 'summary-count', class: 'flex-row fx-align-stretch fx-gap-2' }, div({ class: 'tg-summary-counts--bar', style: `background-color: ${colorMap[item.color] || item.color};` }), div( - div({ class: 'text-caption' }, item.label), - div({ class: 'tg-summary-counts--count' }, formatNumber(item.value)), + div({ 'data-testid': 'summary-count-label', class: 'text-caption' }, item.label), + div({ 'data-testid': 'summary-count-value', class: 'tg-summary-counts--count' }, formatNumber(item.value)), ) )), ); diff --git a/testgen/ui/static/js/components/table.js b/testgen/ui/static/js/components/table.js index c4016f54..62a66722 100644 --- a/testgen/ui/static/js/components/table.js +++ b/testgen/ui/static/js/components/table.js @@ -176,6 +176,7 @@ const Table = (options, rows) => { return div( { + 'data-testid': 'table', class: () => `tg-table flex-column border border-radius-1 ${getValue(options.highDensity) ? 'tg-table-high-density' : ''} ${getValue(options.dynamicWidth) ? 'tg-table-dynamic-width' : ''} ${(getValue(options.uppercaseHeader) ?? true) ? 'tg-table-uppercase-header' : ''} ${options.selection?.onRowsSelected ? 'tg-table-hoverable' : ''}`, style: () => `height: ${getValue(options.height) ? getValue(options.height) : defaultHeight}; ${getValue(options.maxHeight) ? 'max-height: ' + getValue(options.maxHeight) + ';' : ''}`, }, @@ -222,7 +223,7 @@ const Table = (options, rows) => { const rows_ = getValue(rows); if (rows_.length <= 0 && options.emptyState) { return tbody( - {class: 'tg-table-empty-state-body'}, + {'data-testid': 'table-empty', class: 'tg-table-empty-state-body'}, tr( td( {colspan: dataColumns.val.length}, @@ -413,7 +414,7 @@ const Paginatior = ( const sizeOptions = (pageSizeOptions ?? defaultPageSizeOptions).map(n => ({ label: String(n), value: n })); return div( - {class: `tg-table-paginator flex-row fx-justify-content-flex-end ${highDensity ? '' : 'p-1'} text-secondary`}, + {'data-testid': 'table-paginator', class: `tg-table-paginator flex-row fx-justify-content-flex-end ${highDensity ? '' : 'p-1'} text-secondary`}, leftContent, leftContent != undefined ? span({class: 'fx-flex'}) : '', diff --git a/testgen/ui/static/js/components/table_group_wizard.js b/testgen/ui/static/js/components/table_group_wizard.js index 05873061..03610738 100644 --- a/testgen/ui/static/js/components/table_group_wizard.js +++ b/testgen/ui/static/js/components/table_group_wizard.js @@ -461,7 +461,7 @@ const TableGroupWizard = (props) => { return div( { class: 'flex-column' }, div( - { class: 'flex-column fx-gap-4 mb-4 p-5 border border-radius-2' }, + { class: 'flex-column fx-gap-4 mb-4 p-5 border border-radius-2', 'data-testid': 'wizard-success-panel' }, div( { class: 'flex-row fx-gap-2' }, Icon({ style: 'color: var(--green);' }, 'check_circle'), diff --git a/testgen/ui/static/js/components/textarea.js b/testgen/ui/static/js/components/textarea.js index 9e08f9e6..5a004a7b 100644 --- a/testgen/ui/static/js/components/textarea.js +++ b/testgen/ui/static/js/components/textarea.js @@ -86,6 +86,7 @@ const Textarea = (/** @type Properties */ props) => { class: () => `tg-textarea--field ${getValue(props.disabled) ? 'tg-textarea--disabled' : ''}`, style: () => `min-height: ${getValue(props.height) || defaultHeight}px;`, value, + 'data-value': value, name: props.name ?? '', disabled: props.disabled, placeholder: () => getValue(props.placeholder) ?? '', From 42573690f779b95f6212d269fda56f386a912fe5 Mon Sep 17 00:00:00 2001 From: Ricardo Boni Date: Thu, 7 May 2026 12:55:18 -0400 Subject: [PATCH 03/58] feat(server): harden API + MCP server for production deployments (TG-1065) Four bundled fixes for production readiness: - DNS rebinding: pass explicit transport_security to FastMCP with allowlist derived from BASE_URL + loopback + TG_MCP_EXTRA_ALLOWED_HOSTS. Fixes 421 Misdirected Request for external clients caused by FastMCP's loopback-only auto-allowlist. - Security headers: pure-ASGI SecurityHeadersMiddleware injects HSTS (TLS-only by default), X-Content-Type-Options, Referrer-Policy, and CSP frame-ancestors on success and error responses across /api/*, /oauth/*, /.well-known/*, /mcp. - Body-size cap: pure-ASGI BodySizeLimitMiddleware rejects requests exceeding TG_API_MAX_REQUEST_BODY_BYTES (default 10 MiB) with 413, enforced via Content-Length fast-reject and a streaming guard with a latch to prevent post-disconnect bypass. - Graceful shutdown: timeout_graceful_shutdown plumbed to uvicorn.run via TG_API_GRACEFUL_SHUTDOWN_TIMEOUT (default 30s). All settings env-overridable. Pure-ASGI middlewares chosen over BaseHTTPMiddleware to preserve MCP's text/event-stream transport. Tests: 11 cases for the two middlewares (covers latch regression), 7 cases for the transport_security helper. Co-Authored-By: Claude Opus 4.7 (1M context) --- testgen/mcp/server.py | 41 +++ testgen/server/__init__.py | 25 +- testgen/server/middleware.py | 114 ++++++++ testgen/settings.py | 62 +++++ tests/unit/mcp/test_transport_security.py | 86 ++++++ tests/unit/server/test_middleware.py | 320 ++++++++++++++++++++++ 6 files changed, 647 insertions(+), 1 deletion(-) create mode 100644 testgen/server/middleware.py create mode 100644 tests/unit/mcp/test_transport_security.py create mode 100644 tests/unit/server/test_middleware.py diff --git a/testgen/mcp/server.py b/testgen/mcp/server.py index 1358db55..4d53450b 100644 --- a/testgen/mcp/server.py +++ b/testgen/mcp/server.py @@ -1,11 +1,14 @@ import logging +from urllib.parse import urlparse from mcp.server.auth.provider import AccessToken from mcp.server.auth.settings import AuthSettings from mcp.server.fastmcp import FastMCP from mcp.server.streamable_http_manager import StreamableHTTPSessionManager +from mcp.server.transport_security import TransportSecuritySettings from starlette.applications import Starlette +from testgen import settings from testgen.common.auth import decode_jwt_token from testgen.mcp.permissions import set_mcp_token, set_mcp_username @@ -75,6 +78,43 @@ def _configure_mcp_logging() -> None: logging.getLogger(name).parent = testgen_logger +def _build_transport_security() -> TransportSecuritySettings: + """Build DNS-rebinding allowlist from BASE_URL plus operator extras and loopback. + + Without an explicit transport_security, FastMCP installs a loopback-only + allowlist that rejects external Host headers with 421. We pass this settings + object so production deployments accept their own externally-reachable host. + """ + parsed = urlparse(settings.BASE_URL) + base_host = parsed.hostname or "localhost" + netloc = parsed.netloc + scheme = parsed.scheme or "http" + + allowed_hosts: set[str] = { + netloc, + f"{base_host}:*", + "127.0.0.1:*", + "localhost:*", + "[::1]:*", + } + allowed_origins: set[str] = { + f"{scheme}://{netloc}", + "http://127.0.0.1:*", "https://127.0.0.1:*", + "http://localhost:*", "https://localhost:*", + "http://[::1]:*", "https://[::1]:*", + } + for host in settings.MCP_EXTRA_ALLOWED_HOSTS: + host_pattern = host if ":" in host else f"{host}:*" + allowed_hosts.add(host_pattern) + allowed_origins.update({f"http://{host_pattern}", f"https://{host_pattern}"}) + + return TransportSecuritySettings( + enable_dns_rebinding_protection=True, + allowed_hosts=sorted(allowed_hosts), + allowed_origins=sorted(allowed_origins), + ) + + def build_mcp_server( api_base_url: str, server_url: str | None = None, @@ -138,6 +178,7 @@ def build_mcp_server( resource_server_url=server_url, ), token_verifier=JWTTokenVerifier(), + transport_security=_build_transport_security(), ) _configure_mcp_logging() diff --git a/testgen/server/__init__.py b/testgen/server/__init__.py index 120a7789..7109c420 100644 --- a/testgen/server/__init__.py +++ b/testgen/server/__init__.py @@ -27,6 +27,7 @@ from testgen.api.test_definitions import router as test_definitions_router from testgen.common import version_service from testgen.common.models import with_database_session +from testgen.server.middleware import BodySizeLimitMiddleware, SecurityHeadersMiddleware LOG = logging.getLogger("testgen") @@ -131,6 +132,21 @@ def favicon(): if settings.MCP_ENABLED: app.mount("", mcp_app) + # add_middleware is LIFO — body cap is added first so it runs innermost, + # rejecting oversized requests before security headers wrap the 413 response + app.add_middleware(BodySizeLimitMiddleware, max_bytes=settings.API_MAX_REQUEST_BODY_BYTES) + + hsts = settings.API_HSTS_HEADER or ( + "max-age=63072000; includeSubDomains" if settings.API_TLS_ENABLED else None + ) + app.add_middleware( + SecurityHeadersMiddleware, + hsts=hsts, + csp=settings.API_CSP_HEADER, + referrer=settings.API_REFERRER_POLICY, + nosniff=True, + ) + if settings.IS_DEBUG: from starlette.middleware.cors import CORSMiddleware @@ -171,4 +187,11 @@ def run_server() -> None: "enabled" if settings.API_TLS_ENABLED else "disabled", "enabled" if settings.MCP_ENABLED else "disabled", ) - uvicorn.run(app, host=settings.API_HOST, port=settings.API_PORT, log_level="info", **ssl_kwargs) + uvicorn.run( + app, + host=settings.API_HOST, + port=settings.API_PORT, + log_level="info", + timeout_graceful_shutdown=settings.API_GRACEFUL_SHUTDOWN_TIMEOUT, + **ssl_kwargs, + ) diff --git a/testgen/server/middleware.py b/testgen/server/middleware.py new file mode 100644 index 00000000..56661a6b --- /dev/null +++ b/testgen/server/middleware.py @@ -0,0 +1,114 @@ +"""ASGI middlewares for the combined FastAPI + MCP server. + +These are pure-ASGI implementations (not BaseHTTPMiddleware) to avoid buffering +responses, which would break MCP's text/event-stream transport. +""" + +from starlette.types import ASGIApp, Message, Receive, Scope, Send + +_413_BODY = b'{"detail":"Request body too large"}' + + +async def _send_413(send: Send) -> None: + await send( + { + "type": "http.response.start", + "status": 413, + "headers": [ + (b"content-type", b"application/json"), + (b"content-length", str(len(_413_BODY)).encode()), + ], + } + ) + await send({"type": "http.response.body", "body": _413_BODY}) + + +class BodySizeLimitMiddleware: + """Reject requests whose body exceeds *max_bytes* with HTTP 413. + + Checks Content-Length up front when present; otherwise tracks accumulated + body bytes and disconnects when the limit is exceeded mid-stream. Only + inspects http.request messages, so MCP SSE response streams pass through + untouched. + """ + + def __init__(self, app: ASGIApp, max_bytes: int) -> None: + self.app = app + self.max_bytes = max_bytes + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http" or scope.get("method") in ("GET", "HEAD", "OPTIONS"): + await self.app(scope, receive, send) + return + + content_length = next( + (v for k, v in scope.get("headers", []) if k == b"content-length"), None + ) + if content_length is not None: + try: + if int(content_length) > self.max_bytes: + await _send_413(send) + return + except ValueError: + pass + + received = 0 + exceeded = False + + async def limited_receive() -> Message: + nonlocal received, exceeded + if exceeded: + return {"type": "http.disconnect"} + message = await receive() + if message["type"] == "http.request": + received += len(message.get("body", b"")) + if received > self.max_bytes: + exceeded = True + return {"type": "http.disconnect"} + return message + + await self.app(scope, limited_receive, send) + + +class SecurityHeadersMiddleware: + """Inject standard security headers on every HTTP response. + + Headers are added to http.response.start, so they apply uniformly to success + and error responses. Existing headers (case-insensitive match) are preserved, + letting per-route handlers override defaults. + """ + + def __init__( + self, + app: ASGIApp, + *, + hsts: str | None, + csp: str, + referrer: str, + nosniff: bool, + ) -> None: + self.app = app + self.headers: list[tuple[bytes, bytes]] = [] + if hsts: + self.headers.append((b"strict-transport-security", hsts.encode())) + if nosniff: + self.headers.append((b"x-content-type-options", b"nosniff")) + if referrer: + self.headers.append((b"referrer-policy", referrer.encode())) + if csp: + self.headers.append((b"content-security-policy", csp.encode())) + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + async def send_wrapper(message: Message) -> None: + if message["type"] == "http.response.start": + existing = {k.lower() for k, _ in message.get("headers", [])} + for name, value in self.headers: + if name not in existing: + message["headers"].append((name, value)) + await send(message) + + await self.app(scope, receive, send_wrapper) diff --git a/testgen/settings.py b/testgen/settings.py index 94c4e002..339592cf 100644 --- a/testgen/settings.py +++ b/testgen/settings.py @@ -618,3 +618,65 @@ def _default_ui_base_url() -> str: from env variable: `TG_UI_BASE_URL` defaults to: computed from UI_TLS_ENABLED and UI_PORT """ + +MCP_EXTRA_ALLOWED_HOSTS: list[str] = [ + h.strip() for h in (getenv("TG_MCP_EXTRA_ALLOWED_HOSTS", "") or "").split(",") if h.strip() +] +""" +Extra Host header values accepted by MCP DNS rebinding protection (comma-separated). +BASE_URL's hostname and loopback are always allowed; this adds more for multi-domain +deployments or reverse proxies that rewrite Host. Entries without a port (`tg.example.com`) +get an automatic `:*` wildcard; entries with a port are matched literally +(`tg.example.com:8080`) or with explicit wildcard (`tg.example.com:*`). +Only affects MCP routes — the parent FastAPI app does not validate Host headers. + +from env variable: `TG_MCP_EXTRA_ALLOWED_HOSTS` +defaults to: empty (BASE_URL hostname + loopback only) +""" + +API_MAX_REQUEST_BODY_BYTES: int = int( + getenv("TG_API_MAX_REQUEST_BODY_BYTES", str(10 * 1024 * 1024)) +) +""" +Reject HTTP requests larger than this with 413 Payload Too Large. + +from env variable: `TG_API_MAX_REQUEST_BODY_BYTES` +defaults to: 10485760 (10 MiB) +""" + +API_GRACEFUL_SHUTDOWN_TIMEOUT: int = int(getenv("TG_API_GRACEFUL_SHUTDOWN_TIMEOUT", "30")) +""" +Seconds uvicorn waits for in-flight requests on SIGTERM before force-closing. +Long blocking SQL queries that don't honor asyncio cancellation may be cut mid-flight; +align with the target DB's statement_timeout. + +from env variable: `TG_API_GRACEFUL_SHUTDOWN_TIMEOUT` +defaults to: 30 +""" + +API_HSTS_HEADER: str = getenv("TG_API_HSTS_HEADER", "") +""" +Override HSTS (Strict-Transport-Security) header value. When empty, HSTS is emitted +only when API_TLS_ENABLED with value 'max-age=63072000; includeSubDomains'. Setting +this forces emission regardless of TLS (useful when TLS terminates at a reverse proxy). + +from env variable: `TG_API_HSTS_HEADER` +defaults to: empty (auto from API_TLS_ENABLED) +""" + +API_CSP_HEADER: str = getenv("TG_API_CSP_HEADER", "frame-ancestors 'none'") +""" +Content-Security-Policy header value. Default restricts framing only; broader policies +risk breaking Redoc at /api/docs which loads CDN assets. + +from env variable: `TG_API_CSP_HEADER` +defaults to: `frame-ancestors 'none'` +""" + +API_REFERRER_POLICY: str = getenv("TG_API_REFERRER_POLICY", "no-referrer") +""" +Referrer-Policy header value. + +from env variable: `TG_API_REFERRER_POLICY` +defaults to: `no-referrer` +""" diff --git a/tests/unit/mcp/test_transport_security.py b/tests/unit/mcp/test_transport_security.py new file mode 100644 index 00000000..22b101ec --- /dev/null +++ b/tests/unit/mcp/test_transport_security.py @@ -0,0 +1,86 @@ +"""Tests for testgen.mcp.server._build_transport_security — DNS rebinding allowlist builder.""" + +from unittest.mock import patch + +from testgen.mcp.server import _build_transport_security + + +def _build_with(base_url: str, extras: list[str] | None = None): + with ( + patch("testgen.mcp.server.settings.BASE_URL", base_url), + patch("testgen.mcp.server.settings.MCP_EXTRA_ALLOWED_HOSTS", extras or []), + ): + return _build_transport_security() + + +def test_loopback_and_base_url_always_present(): + """With no extras, the allowlist is BASE_URL hosts + loopback variants.""" + settings = _build_with("http://tg.example.com:8530") + + assert settings.enable_dns_rebinding_protection is True + assert "tg.example.com:8530" in settings.allowed_hosts + assert "tg.example.com:*" in settings.allowed_hosts + assert "127.0.0.1:*" in settings.allowed_hosts + assert "localhost:*" in settings.allowed_hosts + assert "[::1]:*" in settings.allowed_hosts + + assert "http://tg.example.com:8530" in settings.allowed_origins + # Loopback origins covered for both schemes + assert "http://localhost:*" in settings.allowed_origins + assert "https://localhost:*" in settings.allowed_origins + + +def test_extra_host_without_port_gets_wildcard(): + """An extras entry without `:` gets `:*` automatically appended.""" + settings = _build_with("http://localhost:8530", extras=["tg.example.com"]) + + assert "tg.example.com:*" in settings.allowed_hosts + assert "tg.example.com" not in settings.allowed_hosts # bare entry should NOT be present + assert "http://tg.example.com:*" in settings.allowed_origins + assert "https://tg.example.com:*" in settings.allowed_origins + + +def test_extra_host_with_explicit_port_preserved_literally(): + """An extras entry with an explicit port is kept as-is, no wildcard appended.""" + settings = _build_with("http://localhost:8530", extras=["tg.example.com:8080"]) + + assert "tg.example.com:8080" in settings.allowed_hosts + assert "tg.example.com:8080:*" not in settings.allowed_hosts # no double-port + + assert "http://tg.example.com:8080" in settings.allowed_origins + assert "https://tg.example.com:8080" in settings.allowed_origins + + +def test_extra_host_with_explicit_wildcard_preserved(): + """An extras entry with `:*` is kept as-is.""" + settings = _build_with("http://localhost:8530", extras=["tg.example.com:*"]) + + assert "tg.example.com:*" in settings.allowed_hosts + assert "http://tg.example.com:*" in settings.allowed_origins + + +def test_mixed_extras(): + """Multiple extras with different shapes are all handled correctly.""" + settings = _build_with( + "http://localhost:8530", + extras=["foo.com", "bar.io:9000", "baz.net:*"], + ) + + assert "foo.com:*" in settings.allowed_hosts + assert "bar.io:9000" in settings.allowed_hosts + assert "baz.net:*" in settings.allowed_hosts + + +def test_https_base_url_origin_uses_https_scheme(): + """Origin scheme tracks BASE_URL's scheme.""" + settings = _build_with("https://tg.example.com") + + assert "https://tg.example.com" in settings.allowed_origins + + +def test_results_are_sorted_lists(): + """allowed_hosts and allowed_origins are deterministic (sorted) for stable diffs.""" + settings = _build_with("http://localhost:8530", extras=["zeta.com", "alpha.com"]) + + assert settings.allowed_hosts == sorted(settings.allowed_hosts) + assert settings.allowed_origins == sorted(settings.allowed_origins) diff --git a/tests/unit/server/test_middleware.py b/tests/unit/server/test_middleware.py new file mode 100644 index 00000000..2225deb2 --- /dev/null +++ b/tests/unit/server/test_middleware.py @@ -0,0 +1,320 @@ +"""Tests for testgen.server.middleware — pure-ASGI body cap and security headers.""" + +# ASGI test stubs (receive/send/inner-app) must be async per protocol but don't +# await anything in these tests. RUF029 is a false positive for that pattern. +# ruff: noqa: RUF029 + +import asyncio +import json + +from testgen.server.middleware import BodySizeLimitMiddleware, SecurityHeadersMiddleware + + +def _http_scope(method: str = "POST", headers: list[tuple[bytes, bytes]] | None = None) -> dict: + return {"type": "http", "method": method, "headers": headers or []} + + +# -------------------------- BodySizeLimitMiddleware -------------------------- + + +def test_body_cap_content_length_over_limit_rejects_immediately(): + """Content-Length > max_bytes → 413 sent without invoking the inner app.""" + inner_called = False + + async def inner(scope, receive, send): + nonlocal inner_called + inner_called = True + + mw = BodySizeLimitMiddleware(inner, max_bytes=1024) + scope = _http_scope(headers=[(b"content-length", b"2048")]) + + sent: list[dict] = [] + + async def send(msg): + sent.append(msg) + + async def receive(): + return {"type": "http.disconnect"} + + asyncio.run(mw(scope, receive, send)) + + assert not inner_called + assert sent[0]["type"] == "http.response.start" + assert sent[0]["status"] == 413 + assert json.loads(sent[1]["body"]) == {"detail": "Request body too large"} + + +def test_body_cap_content_length_under_limit_passes_through(): + """Content-Length under the limit → inner app runs normally.""" + received_by_inner: list[dict] = [] + + async def inner(scope, receive, send): + received_by_inner.append(await receive()) + await send({"type": "http.response.start", "status": 200, "headers": []}) + + mw = BodySizeLimitMiddleware(inner, max_bytes=1024) + scope = _http_scope(headers=[(b"content-length", b"100")]) + + queued = [{"type": "http.request", "body": b"x" * 100, "more_body": False}] + sent: list[dict] = [] + + async def send(msg): + sent.append(msg) + + async def receive(): + return queued.pop(0) if queued else {"type": "http.disconnect"} + + asyncio.run(mw(scope, receive, send)) + + assert received_by_inner[0]["body"] == b"x" * 100 + assert sent[0]["status"] == 200 + + +def test_body_cap_streaming_disconnects_when_exceeded(): + """Without Content-Length, accumulating body chunks past the limit returns disconnect.""" + received_by_inner: list[dict] = [] + + async def inner(scope, receive, send): + # Drain three chunks: third one pushes past the limit + for _ in range(3): + received_by_inner.append(await receive()) + + mw = BodySizeLimitMiddleware(inner, max_bytes=150) + scope = _http_scope(headers=[]) + + queued = [ + {"type": "http.request", "body": b"x" * 100, "more_body": True}, + {"type": "http.request", "body": b"y" * 100, "more_body": True}, + {"type": "http.request", "body": b"z" * 100, "more_body": False}, + ] + + async def send(msg): + pass + + async def receive(): + return queued.pop(0) if queued else {"type": "http.disconnect"} + + asyncio.run(mw(scope, receive, send)) + + # First chunk passes (100 bytes < 150). Second chunk pushes total to 200, exceeds, returns disconnect. + assert received_by_inner[0]["body"] == b"x" * 100 + assert received_by_inner[1]["type"] == "http.disconnect" + + +def test_body_cap_latch_holds_across_repeated_receives(): + """Regression: once exceeded, every subsequent receive() returns disconnect. + + Without the latch, an inner app that drains receive() multiple times after + seeing http.disconnect could read more body bytes from the underlying socket, + bypassing the cap. + """ + received_by_inner: list[dict] = [] + + async def inner(scope, receive, send): + # Drain 5 times, well past the disconnect + for _ in range(5): + received_by_inner.append(await receive()) + + mw = BodySizeLimitMiddleware(inner, max_bytes=50) + scope = _http_scope(headers=[]) + + queued = [ + {"type": "http.request", "body": b"x" * 100, "more_body": True}, # exceeds immediately + {"type": "http.request", "body": b"y" * 100, "more_body": True}, # would exceed again if reached + {"type": "http.request", "body": b"z" * 100, "more_body": False}, + ] + + async def send(msg): + pass + + async def receive(): + return queued.pop(0) if queued else {"type": "http.disconnect"} + + asyncio.run(mw(scope, receive, send)) + + # First call: real chunk (100 bytes), exceeds → returns disconnect + assert received_by_inner[0]["type"] == "http.disconnect" + # Subsequent calls: latch keeps returning disconnect, never forwards real chunks + for msg in received_by_inner[1:]: + assert msg["type"] == "http.disconnect" + + +def test_body_cap_get_request_bypasses(): + """GET requests skip the cap — no body to inspect.""" + received_by_inner: list[dict] = [] + + async def inner(scope, receive, send): + received_by_inner.append("called") + + mw = BodySizeLimitMiddleware(inner, max_bytes=100) + scope = _http_scope(method="GET", headers=[(b"content-length", b"99999")]) + + async def send(msg): + pass + + async def receive(): + return {"type": "http.disconnect"} + + asyncio.run(mw(scope, receive, send)) + + assert received_by_inner == ["called"] # inner ran despite huge Content-Length + + +def test_body_cap_non_http_scope_passes_through(): + """Lifespan/websocket scopes bypass entirely.""" + inner_called = False + + async def inner(scope, receive, send): + nonlocal inner_called + inner_called = True + + mw = BodySizeLimitMiddleware(inner, max_bytes=10) + scope = {"type": "lifespan"} + + async def send(msg): + pass + + async def receive(): + return {"type": "lifespan.shutdown"} + + asyncio.run(mw(scope, receive, send)) + + assert inner_called + + +def test_body_cap_malformed_content_length_falls_through_to_streaming(): + """Non-numeric Content-Length doesn't crash; streaming guard still applies.""" + received_by_inner: list[dict] = [] + + async def inner(scope, receive, send): + received_by_inner.append(await receive()) + + mw = BodySizeLimitMiddleware(inner, max_bytes=50) + scope = _http_scope(headers=[(b"content-length", b"not-a-number")]) + + queued = [{"type": "http.request", "body": b"x" * 100, "more_body": False}] + + async def send(msg): + pass + + async def receive(): + return queued.pop(0) if queued else {"type": "http.disconnect"} + + asyncio.run(mw(scope, receive, send)) + + # Streaming guard catches the oversized body + assert received_by_inner[0]["type"] == "http.disconnect" + + +# -------------------------- SecurityHeadersMiddleware -------------------------- + + +def test_security_headers_added_to_response_start(): + """All configured headers are injected on http.response.start.""" + async def inner(scope, receive, send): + await send({"type": "http.response.start", "status": 200, "headers": []}) + await send({"type": "http.response.body", "body": b""}) + + mw = SecurityHeadersMiddleware( + inner, + hsts="max-age=63072000", + csp="frame-ancestors 'none'", + referrer="no-referrer", + nosniff=True, + ) + scope = _http_scope(method="GET") + sent: list[dict] = [] + + async def send(msg): + sent.append(msg) + + async def receive(): + return {"type": "http.disconnect"} + + asyncio.run(mw(scope, receive, send)) + + headers = dict(sent[0]["headers"]) + assert headers[b"strict-transport-security"] == b"max-age=63072000" + assert headers[b"content-security-policy"] == b"frame-ancestors 'none'" + assert headers[b"referrer-policy"] == b"no-referrer" + assert headers[b"x-content-type-options"] == b"nosniff" + + +def test_security_headers_preserve_handler_set_value(): + """If the handler already sets CSP, the middleware does not override it. + + Case-insensitive: handler-set 'Content-Security-Policy' wins over middleware's lowercase form. + """ + async def inner(scope, receive, send): + await send({ + "type": "http.response.start", + "status": 200, + "headers": [(b"Content-Security-Policy", b"default-src 'self'")], + }) + + mw = SecurityHeadersMiddleware( + inner, + hsts=None, + csp="frame-ancestors 'none'", + referrer="no-referrer", + nosniff=True, + ) + scope = _http_scope(method="GET") + sent: list[dict] = [] + + async def send(msg): + sent.append(msg) + + async def receive(): + return {"type": "http.disconnect"} + + asyncio.run(mw(scope, receive, send)) + + csp_values = [v for k, v in sent[0]["headers"] if k.lower() == b"content-security-policy"] + assert csp_values == [b"default-src 'self'"] + + +def test_security_headers_skip_hsts_when_none(): + """hsts=None → no HSTS header emitted (the API_TLS_ENABLED=False default path).""" + async def inner(scope, receive, send): + await send({"type": "http.response.start", "status": 200, "headers": []}) + + mw = SecurityHeadersMiddleware( + inner, hsts=None, csp="frame-ancestors 'none'", referrer="no-referrer", nosniff=True, + ) + scope = _http_scope(method="GET") + sent: list[dict] = [] + + async def send(msg): + sent.append(msg) + + async def receive(): + return {"type": "http.disconnect"} + + asyncio.run(mw(scope, receive, send)) + + header_names = {k.lower() for k, _ in sent[0]["headers"]} + assert b"strict-transport-security" not in header_names + + +def test_security_headers_non_http_scope_passes_through(): + """Lifespan and other non-http scopes are unmodified.""" + inner_called = False + + async def inner(scope, receive, send): + nonlocal inner_called + inner_called = True + + mw = SecurityHeadersMiddleware( + inner, hsts=None, csp="frame-ancestors 'none'", referrer="no-referrer", nosniff=True, + ) + + async def send(msg): + pass + + async def receive(): + return {"type": "lifespan.shutdown"} + + asyncio.run(mw({"type": "lifespan"}, receive, send)) + + assert inner_called From ac8baa2509fcca2196b7e9defc87c85ed5e96cb2 Mon Sep 17 00:00:00 2001 From: Ricardo Boni Date: Fri, 8 May 2026 15:15:19 -0400 Subject: [PATCH 04/58] feat(mcp): add run status & history tools (TG-1050) Adds list_profiling_runs, get_profiling_run, and get_test_run. Renames get_recent_test_runs -> list_test_runs (adding status and table_group_id filters) and get_test_result_history -> list_test_result_history (TG-1036). Pending/queued JEs surface in a dedicated "Pending" section when scoped by suite or table group via the new JobExecution.select_active_by_kwargs helper. The same kwargs-search pattern is added on JobSchedule for the "Next scheduled run" lookup. select_summary on TestRun and ProfilingRun gains job_execution_id and statuses filters; ProfilingRunSummary now exposes project_code so the by-id tools no longer need a second query. ProfilingRun.select_table_breakdown is the per-table breakdown used by get_profiling_run, written in ORM. Co-Authored-By: Claude Opus 4.7 (1M context) --- testgen/common/models/job_execution.py | 35 ++- testgen/common/models/profiling_run.py | 80 +++++- testgen/common/models/scheduler.py | 25 ++ testgen/common/models/test_run.py | 10 + testgen/mcp/prompts/workflows.py | 4 +- testgen/mcp/server.py | 19 +- testgen/mcp/tools/common.py | 49 +++- testgen/mcp/tools/execution.py | 8 +- testgen/mcp/tools/profiling.py | 192 ++++++++++++- testgen/mcp/tools/test_results.py | 12 +- testgen/mcp/tools/test_runs.py | 326 ++++++++++++++++++---- tests/unit/mcp/test_tools_execution.py | 4 +- tests/unit/mcp/test_tools_profiling.py | 195 +++++++++++++ tests/unit/mcp/test_tools_test_results.py | 24 +- tests/unit/mcp/test_tools_test_runs.py | 305 +++++++++++++------- 15 files changed, 1108 insertions(+), 180 deletions(-) diff --git a/testgen/common/models/job_execution.py b/testgen/common/models/job_execution.py index 49aa67b9..24a23cbc 100644 --- a/testgen/common/models/job_execution.py +++ b/testgen/common/models/job_execution.py @@ -1,7 +1,7 @@ import logging from datetime import UTC, datetime from enum import StrEnum -from typing import Any, Self +from typing import Any, ClassVar, Self from uuid import UUID, uuid4 from sqlalchemy import Column, String, Text, case, func, select, text, update @@ -101,6 +101,39 @@ def claim_actionable(cls, limit: int = 5) -> list[Self]: LOG.info("Claimed %d pending job execution(s)", claimed) return rows + _ACTIVE_STATUSES: ClassVar[list[JobStatus]] = [ + JobStatus.PENDING, JobStatus.CLAIMED, JobStatus.RUNNING, JobStatus.CANCEL_REQUESTED, + ] + + @classmethod + def select_active_by_kwargs( + cls, + project_code: str, + job_key: str, + kwargs_match: dict[str, str | list[str]], + statuses: list[JobStatus] | None = None, + ) -> list[Self]: + """Find JE rows whose ``kwargs`` JSONB matches the given (key, value) pairs. + + Values may be a single string or a list of strings (which becomes an ``IN`` filter). + Defaults to active (non-terminal) statuses. + """ + statuses = statuses or cls._ACTIVE_STATUSES + query = select(cls).where( + cls.project_code == project_code, + cls.job_key == job_key, + cls.status.in_(statuses), + ) + for k, v in kwargs_match.items(): + if isinstance(v, list): + if not v: + return [] + query = query.where(cls.kwargs[k].astext.in_([str(x) for x in v])) + else: + query = query.where(cls.kwargs[k].astext == str(v)) + query = query.order_by(cls.created_at.desc()) + return list(get_current_session().scalars(query).all()) + @classmethod def find_stale(cls) -> list[Self]: """Return job executions left in non-terminal states from a previous process.""" diff --git a/testgen/common/models/profiling_run.py b/testgen/common/models/profiling_run.py index 05a5a94f..cbdf98b1 100644 --- a/testgen/common/models/profiling_run.py +++ b/testgen/common/models/profiling_run.py @@ -15,6 +15,7 @@ from testgen.common.models.connection import Connection from testgen.common.models.entity import ENTITY_HASH_FUNCS, Entity, EntityMinimal from testgen.common.models.job_execution import JobExecution, JobStatus +from testgen.common.models.profile_result import ProfileResult from testgen.common.models.project import Project from testgen.common.models.table_group import TableGroup from testgen.utils import is_uuid4 @@ -46,6 +47,7 @@ class ProfilingRunMinimal(EntityMinimal): class ProfilingRunSummary(EntityMinimal): job_execution_id: UUID profiling_run_id: UUID | None + project_code: str status: JobStatus created_at: datetime started_at: datetime | None @@ -83,6 +85,15 @@ def status_label(self) -> str: return self.STATUS_LABEL.get(self.status, self.status) +@dataclass +class ProfilingRunTableBreakdown(EntityMinimal): + schema_name: str + table_name: str + record_ct: int | None + column_ct: int + anomaly_ct: int + + class LatestProfilingRun(NamedTuple): id: str run_time: datetime @@ -196,14 +207,21 @@ def select_minimal_where( @classmethod def select_summary( cls, - project_code: str, + project_code: str | None = None, table_group_id: str | UUID | None = None, + job_execution_id: str | UUID | None = None, + statuses: list[JobStatus] | None = None, page: int = 1, page_size: int = 20, ) -> tuple[list[ProfilingRunSummary], int]: - if table_group_id and not is_uuid4(table_group_id): + if ( + (table_group_id and not is_uuid4(table_group_id)) + or (job_execution_id and not is_uuid4(job_execution_id)) + ): return [], 0 + # Pending JEs (no pr row) surface in project-scope queries via the LEFT JOIN, but + # not in table-group-scoped queries, since the WHERE filter requires tg to match. query = f""" WITH profile_anomalies AS ( SELECT profile_anomaly_results.profile_run_id, @@ -224,6 +242,7 @@ def select_summary( SELECT je.id AS job_execution_id, pr.id AS profiling_run_id, + je.project_code, je.status, je.created_at, je.started_at, @@ -250,14 +269,18 @@ def select_summary( LEFT JOIN table_groups tg ON tg.id = pr.table_groups_id LEFT JOIN profile_anomalies pa ON pa.profile_run_id = pr.id WHERE je.job_key = 'run-profile' - AND je.project_code = :project_code + {" AND je.project_code = :project_code" if project_code else ""} {" AND tg.id = :table_group_id" if table_group_id else ""} + {" AND je.id = :job_execution_id" if job_execution_id else ""} + {" AND je.status IN :statuses" if statuses else ""} ORDER BY je.created_at DESC LIMIT :limit OFFSET :offset; """ params = { "project_code": project_code, - "table_group_id": table_group_id, + "table_group_id": str(table_group_id) if table_group_id else None, + "job_execution_id": str(job_execution_id) if job_execution_id else None, + "statuses": tuple(statuses) if statuses else (), "limit": page_size, "offset": (page - 1) * page_size, } @@ -267,6 +290,55 @@ def select_summary( total = items[0].total_count if items else 0 return items, total + @classmethod + def select_table_breakdown(cls, profiling_run_id: UUID) -> list[ProfilingRunTableBreakdown]: + """Per-table breakdown for a completed profiling run: schema, table, record/column count, anomaly count.""" + # HygieneIssue imports ProfilingRun, so this import has to stay function-local. + from testgen.common.models.hygiene_issue import HygieneIssue + + results_subq = ( + select( + ProfileResult.schema_name.label("schema_name"), + ProfileResult.table_name.label("table_name"), + func.max(ProfileResult.record_ct).label("record_ct"), + func.count(func.distinct(ProfileResult.column_name)).label("column_ct"), + ) + .where(ProfileResult.profile_run_id == profiling_run_id) + .group_by(ProfileResult.schema_name, ProfileResult.table_name) + .subquery() + ) + anomalies_subq = ( + select( + HygieneIssue.schema_name.label("schema_name"), + HygieneIssue.table_name.label("table_name"), + func.count().label("anomaly_ct"), + ) + .where( + HygieneIssue.profile_run_id == profiling_run_id, + func.coalesce(HygieneIssue.disposition, "Confirmed") == "Confirmed", + ) + .group_by(HygieneIssue.schema_name, HygieneIssue.table_name) + .subquery() + ) + query = ( + select( + results_subq.c.schema_name, + results_subq.c.table_name, + results_subq.c.record_ct, + results_subq.c.column_ct, + func.coalesce(anomalies_subq.c.anomaly_ct, 0).label("anomaly_ct"), + ) + .select_from(results_subq) + .outerjoin( + anomalies_subq, + (anomalies_subq.c.schema_name == results_subq.c.schema_name) + & (anomalies_subq.c.table_name == results_subq.c.table_name), + ) + .order_by(results_subq.c.schema_name, results_subq.c.table_name) + ) + rows = get_current_session().execute(query).mappings().all() + return [ProfilingRunTableBreakdown(**row) for row in rows] + _ACTIVE_JOB_STATUSES = (JobStatus.PENDING, JobStatus.CLAIMED, JobStatus.RUNNING, JobStatus.CANCEL_REQUESTED) @classmethod diff --git a/testgen/common/models/scheduler.py b/testgen/common/models/scheduler.py index f094c4ab..825766a8 100644 --- a/testgen/common/models/scheduler.py +++ b/testgen/common/models/scheduler.py @@ -73,6 +73,31 @@ def update_active(cls, job_id: str | UUID, active: bool) -> None: @classmethod def count(cls): return get_current_session().query(cls).count() + + @classmethod + def select_active_by_kwargs( + cls, + project_code: str, + key: str, + kwargs_match: dict[str, str | list[str]], + ) -> list[Self]: + """Find active schedules whose ``kwargs`` JSONB matches the given (key, value) pairs. + + Values may be a single string or a list of strings (which becomes an ``IN`` filter). + """ + query = select(cls).where( + cls.project_code == project_code, + cls.key == key, + cls.active.is_(True), + ) + for k, v in kwargs_match.items(): + if isinstance(v, list): + if not v: + return [] + query = query.where(cls.kwargs[k].astext.in_([str(x) for x in v])) + else: + query = query.where(cls.kwargs[k].astext == str(v)) + return list(get_current_session().scalars(query).all()) def get_sample_triggering_timestamps(self, n=3) -> list[datetime]: schedule = Cron(cron_string=self.cron_expr).schedule(timezone_str=self.cron_tz) diff --git a/testgen/common/models/test_run.py b/testgen/common/models/test_run.py index 7653f355..c00c6c6b 100644 --- a/testgen/common/models/test_run.py +++ b/testgen/common/models/test_run.py @@ -218,6 +218,8 @@ def select_summary( table_group_id: str | None = None, test_suite_id: str | None = None, test_run_ids: list[str | UUID] | None = None, + job_execution_id: str | UUID | None = None, + statuses: list[JobStatus] | None = None, page: int = 1, page_size: int = 20, ) -> tuple[list[TestRunSummary], int]: @@ -225,9 +227,13 @@ def select_summary( (table_group_id and not is_uuid4(table_group_id)) or (test_suite_id and not is_uuid4(test_suite_id)) or (test_run_ids and not all(is_uuid4(run_id) for run_id in test_run_ids)) + or (job_execution_id and not is_uuid4(job_execution_id)) ): return [], 0 + # Pending JEs (no tr row) surface in project-scope queries — the LEFT JOIN to + # test_suites yields a NULL ts row that the ``ts.id IS NULL`` clause lets through — + # but not in suite/TG-scoped queries, since the WHERE filter requires ts to match. query = f""" WITH run_results AS ( SELECT test_run_id, @@ -282,6 +288,8 @@ def select_summary( {" AND ts.table_groups_id = :table_group_id" if table_group_id else ""} {" AND ts.id = :test_suite_id" if test_suite_id else ""} {" AND tr.id IN :test_run_ids" if test_run_ids else ""} + {" AND je.id = :job_execution_id" if job_execution_id else ""} + {" AND je.status IN :statuses" if statuses else ""} ORDER BY je.created_at DESC LIMIT :limit OFFSET :offset; """ @@ -290,6 +298,8 @@ def select_summary( "table_group_id": table_group_id, "test_suite_id": test_suite_id, "test_run_ids": tuple(test_run_ids or []), + "job_execution_id": str(job_execution_id) if job_execution_id else None, + "statuses": tuple(statuses) if statuses else (), "limit": page_size, "offset": (page - 1) * page_size, } diff --git a/testgen/mcp/prompts/workflows.py b/testgen/mcp/prompts/workflows.py index 4201bf75..bf90493c 100644 --- a/testgen/mcp/prompts/workflows.py +++ b/testgen/mcp/prompts/workflows.py @@ -7,7 +7,7 @@ def health_check() -> str: Please perform a data quality health check: 1. Call `get_data_inventory()` to get a complete overview of all projects, connections, table groups, and test suites. -2. For each project, call `get_recent_test_runs(...)` to get the latest test runs across all suites. +2. For each project, call `list_test_runs(...)` to get the latest test runs across all suites. 3. Summarize the overall health: - Which projects/suites are healthy (all tests passing)? - Which have failures or warnings? @@ -29,7 +29,7 @@ def investigate_failures(test_suite: str | None = None) -> str: Please investigate test failures and identify root causes:{suite_filter} 1. Call `get_data_inventory()` to understand the project structure. -2. Call `get_recent_test_runs(...)` to find the latest run per suite{f" for suite `{test_suite}`" if test_suite else ""}. +2. Call `list_test_runs(...)` to find the latest run per suite{f" for suite `{test_suite}`" if test_suite else ""}. 3. Call `get_failure_summary(job_execution_id='...')` to see failures grouped by test type. 4. For each failure category, call `get_test_type(test_type='...')` to understand what the test checks. 5. Call `list_test_results(test_suite_id='...', status='Failed')` to drill into the specific failing tests in the latest run. diff --git a/testgen/mcp/server.py b/testgen/mcp/server.py index 4d53450b..2b0539fb 100644 --- a/testgen/mcp/server.py +++ b/testgen/mcp/server.py @@ -148,7 +148,13 @@ def build_mcp_server( search_hygiene_issues, update_hygiene_issue, ) - from testgen.mcp.tools.profiling import get_table, list_column_profiles, list_profiling_summaries + from testgen.mcp.tools.profiling import ( + get_profiling_run, + get_table, + list_column_profiles, + list_profiling_runs, + list_profiling_summaries, + ) from testgen.mcp.tools.reference import ( get_test_type, glossary_resource, @@ -160,12 +166,12 @@ def build_mcp_server( from testgen.mcp.tools.test_results import ( get_failure_summary, get_failure_trend, - get_test_result_history, get_test_run_diff, + list_test_result_history, list_test_results, search_test_results, ) - from testgen.mcp.tools.test_runs import get_recent_test_runs + from testgen.mcp.tools.test_runs import get_test_run, list_test_runs if server_url is None: server_url = f"{api_base_url}/mcp" @@ -196,9 +202,10 @@ def safe_prompt(fn): safe_tool(list_projects) safe_tool(list_tables) safe_tool(list_test_suites) - safe_tool(get_recent_test_runs) + safe_tool(list_test_runs) + safe_tool(get_test_run) safe_tool(list_test_results) - safe_tool(get_test_result_history) + safe_tool(list_test_result_history) safe_tool(get_failure_summary) safe_tool(search_test_results) safe_tool(get_failure_trend) @@ -213,6 +220,8 @@ def safe_prompt(fn): safe_tool(get_table) safe_tool(list_column_profiles) safe_tool(list_profiling_summaries) + safe_tool(list_profiling_runs) + safe_tool(get_profiling_run) safe_tool(run_tests) safe_tool(run_profiling) safe_tool(cancel_test_run) diff --git a/testgen/mcp/tools/common.py b/testgen/mcp/tools/common.py index 55b7ff02..4ddb39a2 100644 --- a/testgen/mcp/tools/common.py +++ b/testgen/mcp/tools/common.py @@ -1,10 +1,12 @@ -from datetime import date +from datetime import date, datetime from enum import StrEnum from uuid import UUID from testgen.common.date_service import parse_since from testgen.common.enums import ImpactDimension, QualityDimension from testgen.common.models.hygiene_issue import Disposition, HygieneIssueType, IssueLikelihood, PiiRisk +from testgen.common.models.job_execution import JobStatus +from testgen.common.models.scheduler import JobSchedule from testgen.common.models.table_group import TableGroup from testgen.common.models.test_definition import TestType from testgen.common.models.test_result import TestResultStatus @@ -85,6 +87,51 @@ def parse_quality_dimension(value: str) -> QualityDimension: raise MCPUserError(f"Invalid quality_dimension `{value}`. Valid values: {valid}") from err +# Maps user-facing run-status labels to underlying ``JobStatus`` values. Transient states +# (Starting/Canceling) are excluded because they're sub-second and noisy as filters. +# ``Pending`` collapses PENDING+CLAIMED; ``Canceled`` collapses CANCEL_REQUESTED+CANCELED. +_RUN_STATUS_FILTER: dict[str, list[JobStatus]] = { + "Pending": [JobStatus.PENDING, JobStatus.CLAIMED], + "Running": [JobStatus.RUNNING], + "Completed": [JobStatus.COMPLETED], + "Canceled": [JobStatus.CANCEL_REQUESTED, JobStatus.CANCELED], + "Error": [JobStatus.ERROR], +} + + +def parse_run_status_filter(value: str) -> list[JobStatus]: + """Map a user-facing run status label (e.g. ``Pending``) to the underlying ``JobStatus`` values.""" + statuses = _RUN_STATUS_FILTER.get(value) + if statuses is None: + valid = ", ".join(_RUN_STATUS_FILTER.keys()) + raise MCPUserError(f"Invalid status `{value}`. Valid values: {valid}") + return statuses + + +def format_run_duration(started_at: datetime | None, completed_at: datetime | None) -> str | None: + """Render an elapsed duration as ``Xs`` / ``Xm Ys`` / ``Xh Ym``. Returns ``None`` if either bound is missing.""" + if not started_at or not completed_at: + return None + seconds = int((completed_at - started_at).total_seconds()) + if seconds < 60: + return f"{seconds}s" + if seconds < 3600: + return f"{seconds // 60}m {seconds % 60}s" + return f"{seconds // 3600}h {(seconds % 3600) // 60}m" + + +def next_scheduled_run( + job_key: str, kwargs_filter: dict[str, str | list[str]], project_code: str, +) -> datetime | None: + """Return the next firing of an active ``JobSchedule`` matching ``job_key`` and a kwargs + filter. When multiple schedules match, the soonest next-firing wins. + """ + schedules = JobSchedule.select_active_by_kwargs(project_code, job_key, kwargs_filter) + if not schedules: + return None + return min(s.get_sample_triggering_timestamps(1)[0] for s in schedules) + + def parse_disposition(value: str) -> Disposition: """Validate a user-facing disposition label and return the stored ``Disposition``. diff --git a/testgen/mcp/tools/execution.py b/testgen/mcp/tools/execution.py index b7313535..987cbb4d 100644 --- a/testgen/mcp/tools/execution.py +++ b/testgen/mcp/tools/execution.py @@ -17,7 +17,7 @@ @mcp_permission("edit") def run_tests(test_suite_id: str) -> str: """Submit a test run for a test suite. Returns immediately with a job_execution_id; - use ``get_recent_test_runs`` to track status. + use ``list_test_runs`` to track status. Args: test_suite_id: UUID of the test suite to run, e.g. from ``list_test_suites``. @@ -29,7 +29,7 @@ def run_tests(test_suite_id: str) -> str: source=JobSource.mcp, project_code=suite.project_code, ) - return _render_submission("Test run", suite.test_suite, "Test suite", job, "get_recent_test_runs") + return _render_submission("Test run", suite.test_suite, "Test suite", job, "list_test_runs") @with_database_session @@ -86,10 +86,10 @@ def cancel_test_run(job_execution_id: str) -> str: """Request cancellation of a queued or running test run. Args: - job_execution_id: UUID of a test run, e.g. from ``get_recent_test_runs``. + job_execution_id: UUID of a test run, e.g. from ``list_test_runs``. """ job = _resolve_job_execution(job_execution_id, JobKey.run_tests, "Test run") - return _render_cancel(job, "Test run", "get_recent_test_runs") + return _render_cancel(job, "Test run", "list_test_runs") @with_database_session diff --git a/testgen/mcp/tools/profiling.py b/testgen/mcp/tools/profiling.py index 9d293425..1d8bbb1e 100644 --- a/testgen/mcp/tools/profiling.py +++ b/testgen/mcp/tools/profiling.py @@ -3,7 +3,9 @@ from testgen.common.models import with_database_session from testgen.common.models.data_column import ColumnProfileSummary, DataColumnChars from testgen.common.models.data_table import DataTable -from testgen.common.models.profiling_run import ProfilingRun +from testgen.common.models.job_execution import JobExecution +from testgen.common.models.profiling_run import ProfilingRun, ProfilingRunSummary +from testgen.common.models.scheduler import RUN_PROFILE_JOB_KEY from testgen.common.models.table_group import TableGroup, TableGroupSummary from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError from testgen.mcp.permissions import get_project_permissions, mcp_permission @@ -11,8 +13,13 @@ DocGroup, format_page_footer, format_page_info, + format_run_duration, + next_scheduled_run, + parse_run_status_filter, parse_uuid, resolve_table_group, + validate_limit, + validate_page, ) from testgen.mcp.tools.markdown import MdDoc from testgen.utils import friendly_score @@ -243,6 +250,189 @@ def _render_column_profile_row(c: ColumnProfileSummary) -> list: ] +@with_database_session +@mcp_permission("catalog") +def list_profiling_runs( + table_group_id: str, + status: str | None = None, + limit: int = 10, + page: int = 1, +) -> str: + """List profiling run history for a table group, including queued, in-progress, and failed runs. + Ordered by submission time descending. + + Args: + table_group_id: UUID of the table group, e.g. from `get_data_inventory`. + status: Optional run status filter. One of: Pending, Running, Completed, Canceled, Error. + limit: Page size (default 10, max 100). + page: Page number starting at 1 (default 1). + """ + validate_limit(limit, 100) + validate_page(page) + + statuses = parse_run_status_filter(status) if status else None + tg = resolve_table_group(table_group_id) + + summaries, total = ProfilingRun.select_summary( + project_code=tg.project_code, + table_group_id=tg.id, + statuses=statuses, + page=page, + page_size=limit, + ) + + # Queued/claimed JEs that don't yet have a profiling_runs row are invisible to TG-scoped + # joined-run queries. Surface them as a separate "Pending" section on page 1. + pending_jes: list[JobExecution] = [] + if page == 1: + pending_jes = JobExecution.select_active_by_kwargs( + project_code=tg.project_code, + job_key=RUN_PROFILE_JOB_KEY, + kwargs_match={"table_group_id": str(tg.id)}, + statuses=statuses, + ) + + doc = MdDoc() + scope = f" — status `{status}`" if status else "" + doc.heading(1, f"Profiling runs for `{tg.table_groups_name}`{scope}") + + next_run = next_scheduled_run( + RUN_PROFILE_JOB_KEY, {"table_group_id": str(tg.id)}, tg.project_code + ) + if next_run: + doc.field("Next scheduled run", next_run) + + if pending_jes: + doc.heading(2, f"Pending ({len(pending_jes)})") + for je in pending_jes: + _render_pending_profiling_je(doc, je, label=tg.table_groups_name) + + page_info = format_page_info(total, page, limit) + if page_info: + doc.text(page_info) + + if not summaries: + if page > 1: + doc.text(f"_No profiling runs on page {page} (total: {total})._") + elif not pending_jes: + doc.text("_No profiling runs found._") + return doc.render() + + for run in summaries: + _render_profiling_run_section(doc, run) + + footer = format_page_footer(total, page, limit) + if footer: + doc.text(footer) + + return doc.render() + + +@with_database_session +@mcp_permission("catalog") +def get_profiling_run(job_execution_id: str) -> str: + """Get a single profiling run with status, timing, totals, and per-table breakdown. Returns the + run regardless of state — including queued and in-progress runs without complete results yet. + The per-table breakdown is only available after the run completes. + + Args: + job_execution_id: UUID of a profiling run, e.g. from `list_profiling_runs` or + `list_profiling_summaries`. + """ + parse_uuid(job_execution_id, "job_execution_id") + perms = get_project_permissions() + + summaries, _ = ProfilingRun.select_summary(job_execution_id=job_execution_id, page_size=1) + summary = summaries[0] if summaries else None + if summary is None or summary.project_code not in perms.allowed_codes: + raise MCPResourceNotAccessible("Profiling run", job_execution_id) + + doc = MdDoc() + tg_label = summary.table_groups_name or "—" + doc.heading(1, f"Profiling run: {tg_label}") + doc.field("Job ID", summary.job_execution_id, code=True) + if summary.table_groups_name: + doc.field("Table group", summary.table_groups_name) + if summary.table_group_schema: + doc.field("Schema", summary.table_group_schema) + doc.field("Status", summary.status_label) + doc.field("Submitted", summary.created_at) + doc.field("Started", summary.started_at or "—") + doc.field("Ended", summary.completed_at or "In progress") + duration = format_run_duration(summary.started_at, summary.completed_at) + if duration: + doc.field("Duration", duration) + + has_totals = summary.table_ct or summary.column_ct or summary.record_ct or summary.anomaly_ct + if has_totals: + doc.field("Tables profiled", summary.table_ct or 0) + doc.field("Columns profiled", summary.column_ct or 0) + if summary.record_ct is not None: + doc.field("Records", summary.record_ct) + doc.field( + "Hygiene issues (confirmed)", + f"{(summary.anomalies_definite_ct or 0) + (summary.anomalies_likely_ct or 0) + (summary.anomalies_possible_ct or 0)} total " + f"— {summary.anomalies_definite_ct or 0} definite, " + f"{summary.anomalies_likely_ct or 0} likely, " + f"{summary.anomalies_possible_ct or 0} possible", + ) + if summary.dq_score_profiling is not None: + doc.field("Profiling Score", friendly_score(summary.dq_score_profiling)) + + if summary.profiling_run_id: + breakdown = ProfilingRun.select_table_breakdown(summary.profiling_run_id) + if breakdown: + doc.heading(2, "Per-table breakdown") + doc.table( + ["Schema", "Table", "Records", "Columns", "Hygiene issues"], + [ + [r.schema_name, r.table_name, r.record_ct, r.column_ct, r.anomaly_ct] + for r in breakdown + ], + code=[0, 1], + ) + + if summary.error_message: + doc.heading(2, "Error") + doc.text(summary.error_message) + + return doc.render() + + +def _render_pending_profiling_je(doc: MdDoc, je: JobExecution, label: str) -> None: + status_label = ProfilingRunSummary.STATUS_LABEL.get(je.status, je.status) + doc.heading(3, f"{label} — {status_label}") + doc.field("Job ID", je.id, code=True) + doc.field("Submitted", je.created_at) + doc.field("Started", je.started_at or "—") + doc.field("Ended", je.completed_at or "In progress") + + +def _render_profiling_run_section(doc: MdDoc, run: ProfilingRunSummary) -> None: + title = run.table_groups_name or run.profiling_run_id or run.job_execution_id + doc.heading(2, f"{title} — {run.status_label}") + doc.field("Job ID", run.job_execution_id, code=True) + doc.field("Submitted", run.created_at) + doc.field("Started", run.started_at or "—") + doc.field("Ended", run.completed_at or "In progress") + duration = format_run_duration(run.started_at, run.completed_at) + if duration: + doc.field("Duration", duration) + + if run.table_ct or run.column_ct: + doc.field("Tables profiled", run.table_ct or 0) + doc.field("Columns profiled", run.column_ct or 0) + if run.anomaly_ct is not None and ( + run.anomalies_definite_ct or run.anomalies_likely_ct or run.anomalies_possible_ct + ): + doc.field( + "Hygiene issues (confirmed)", + f"{(run.anomalies_definite_ct or 0) + (run.anomalies_likely_ct or 0) + (run.anomalies_possible_ct or 0)} total", + ) + if run.dq_score_profiling is not None: + doc.field("Profiling Score", friendly_score(run.dq_score_profiling)) + + def _render_table_group_summary(doc: MdDoc, s: TableGroupSummary) -> None: doc.heading(2, s.table_groups_name) if s.connection_name: diff --git a/testgen/mcp/tools/test_results.py b/testgen/mcp/tools/test_results.py index ec708a3a..dff4331b 100644 --- a/testgen/mcp/tools/test_results.py +++ b/testgen/mcp/tools/test_results.py @@ -42,7 +42,7 @@ def list_test_results( the latest completed run of that suite. Args: - job_execution_id: UUID of a test run, e.g. from ``get_recent_test_runs`` or + job_execution_id: UUID of a test run, e.g. from ``list_test_runs`` or ``list_test_suites``. test_suite_id: UUID of a test suite. Resolves to the latest completed test run for the suite. Mutually exclusive with ``job_execution_id``. @@ -158,7 +158,7 @@ def get_failure_summary( Args: project_code: Scope to a project the caller can view. Ignored if ``job_execution_id`` is set. test_suite_id: UUID of a test suite to scope the aggregation to. - job_execution_id: UUID of a test run, e.g. from ``get_recent_test_runs``, + job_execution_id: UUID of a test run, e.g. from ``list_test_runs``, to scope the summary to a single run. since: Include runs since this point in time — e.g. '7 days', '2 weeks', '2026-04-01'. group_by: Group failures by 'test_type', 'table', or 'column' (default: 'test_type'). @@ -263,7 +263,7 @@ def get_failure_summary( @with_database_session @mcp_permission("view") -def get_test_result_history( +def list_test_result_history( test_definition_id: str, limit: int = 20, page: int = 1, @@ -330,7 +330,7 @@ def search_test_results( """Search test results across multiple runs with flexible filters. To drill into a single run, use ``list_test_results``. For a single test's history, use - ``get_test_result_history``. + ``list_test_result_history``. Args: project_code: Scope to a project the caller can view. @@ -524,7 +524,7 @@ def get_test_run_diff(job_execution_id_a: str, job_execution_id_b: str) -> str: """Compare two test runs and report regressions, improvements, persistent failures, and added/removed tests. Args: - job_execution_id_a: UUID of the older (baseline) test run, e.g. from ``get_recent_test_runs``. + job_execution_id_a: UUID of the older (baseline) test run, e.g. from ``list_test_runs``. job_execution_id_b: UUID of the newer test run. """ uuid_a = parse_uuid(job_execution_id_a, "job_execution_id_a") @@ -561,7 +561,7 @@ def _accessible(run) -> bool: raise MCPUserError( "Both runs must belong to the same test suite to be comparable. " f"Run A is in suite `{run_a.test_suite_id}`, run B is in suite `{run_b.test_suite_id}`. " - "Use `get_recent_test_runs(test_suite=...)` to pick two runs of the same suite." + "Use `list_test_runs(test_suite=...)` to pick two runs of the same suite." ) diff = TestResult.diff_with_details(run_a.id, run_b.id) diff --git a/testgen/mcp/tools/test_runs.py b/testgen/mcp/tools/test_runs.py index 68f9ce7b..7ca60dcc 100644 --- a/testgen/mcp/tools/test_runs.py +++ b/testgen/mcp/tools/test_runs.py @@ -1,8 +1,24 @@ +from datetime import datetime + from testgen.common.models import with_database_session -from testgen.common.models.test_run import TestRun +from testgen.common.models.job_execution import JobExecution +from testgen.common.models.scheduler import RUN_TESTS_JOB_KEY +from testgen.common.models.test_run import TestRun, TestRunSummary from testgen.common.models.test_suite import TestSuite +from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError from testgen.mcp.permissions import get_project_permissions, mcp_permission -from testgen.mcp.tools.common import DocGroup, validate_limit +from testgen.mcp.tools.common import ( + DocGroup, + format_page_footer, + format_page_info, + format_run_duration, + next_scheduled_run, + parse_run_status_filter, + parse_uuid, + resolve_table_group, + validate_limit, + validate_page, +) from testgen.mcp.tools.markdown import MdDoc _DOC_GROUP = DocGroup.INVESTIGATE @@ -10,76 +26,290 @@ @with_database_session @mcp_permission("view") -def get_recent_test_runs(project_code: str, test_suite: str | None = None, limit: int = 1) -> str: - """Get the latest test runs for each test suite in a project, optionally filtered by test suite name. +def list_test_runs( + project_code: str | None = None, + test_suite: str | None = None, + table_group_id: str | None = None, + status: str | None = None, + limit: int = 10, + page: int = 1, +) -> str: + """List test runs across a project, including queued and in-progress runs. Ordered by submission + time descending. Excludes monitor suites. Args: - project_code: The project code to query. - test_suite: Optional test suite name to filter by. - limit: Maximum runs per test suite (default 1, max 100). + project_code: Project code to query, e.g. from `list_projects`. Required unless + `table_group_id` is provided (which scopes to a single project). + test_suite: Optional test suite name to filter by (case-sensitive). + table_group_id: Optional UUID of a table group, e.g. from `get_data_inventory`. Returns + runs for any suite in the group. + status: Optional run status filter. One of: Pending, Running, Completed, Canceled, Error. + limit: Page size (default 10, max 100). + page: Page number starting at 1 (default 1). """ - if not project_code: - return "Missing required parameter `project_code`." validate_limit(limit, 100) + validate_page(page) - perms = get_project_permissions() - perms.verify_access(project_code, not_found=f"No completed test runs found in project `{project_code}`.") + statuses = parse_run_status_filter(status) if status else None + if not project_code and not table_group_id: + raise MCPUserError("Provide either `project_code` or `table_group_id`.") + + perms = get_project_permissions() test_suite_id = None + table_group = None + + if table_group_id: + table_group = resolve_table_group(table_group_id) + if project_code and project_code != table_group.project_code: + raise MCPUserError( + f"`project_code` `{project_code}` does not match the table group's project." + ) + project_code = table_group.project_code + else: + perms.verify_access( + project_code, + not_found=MCPResourceNotAccessible("Project", project_code), + ) + if test_suite: suites = TestSuite.select_minimal_where( TestSuite.project_code == project_code, TestSuite.test_suite == test_suite, + TestSuite.is_monitor.isnot(True), ) if not suites: - return f"Test suite `{test_suite}` not found in project `{project_code}`." + raise MCPResourceNotAccessible("Test suite", test_suite) test_suite_id = str(suites[0].id) - summaries, _ = TestRun.select_summary(project_code=project_code, test_suite_id=test_suite_id, page_size=1000) + summaries, total = TestRun.select_summary( + project_code=project_code, + table_group_id=str(table_group.id) if table_group else None, + test_suite_id=test_suite_id, + statuses=statuses, + page=page, + page_size=limit, + ) + + # Queued/claimed JEs that don't yet have a test_runs row are invisible to suite/TG-scoped + # joined-run queries. Surface them as a separate "Pending" section on page 1. + pending_jes: list[JobExecution] = [] + if page == 1 and (test_suite_id or table_group): + pending_jes = _select_pending_test_jes( + project_code=project_code, + test_suite_id=test_suite_id, + table_group_id=str(table_group.id) if table_group else None, + statuses=statuses, + ) + + scope_descriptor = _scope_descriptor(project_code, test_suite, table_group_id, status) + doc = MdDoc() + doc.heading(1, f"Test runs{scope_descriptor}") + + next_run = _next_test_run( + project_code=project_code, + test_suite_id=test_suite_id, + table_group_id=str(table_group.id) if table_group else None, + ) + if next_run: + doc.field("Next scheduled run", next_run) + + if pending_jes: + doc.heading(2, f"Pending ({len(pending_jes)})") + for je in pending_jes: + _render_pending_je(doc, je, label=test_suite or "Test run") + + page_info = format_page_info(total, page, limit) + if page_info: + doc.text(page_info) if not summaries: - scope = f" for suite `{test_suite}`" if test_suite else "" - return f"No completed test runs found in project `{project_code}`{scope}." - - # Take the first `limit` runs per suite (summaries are ordered by test_starttime DESC) - seen: dict[str, int] = {} - runs = [] - for s in summaries: - count = seen.get(s.test_suite, 0) - if count < limit: - runs.append(s) - seen[s.test_suite] = count + 1 + if page > 1: + doc.text(f"_No test runs on page {page} (total: {total})._") + elif not pending_jes: + doc.text("_No test runs found._") + return doc.render() + + for run in summaries: + _render_test_run_section(doc, run) + + footer = format_page_footer(total, page, limit) + if footer: + doc.text(footer) + + return doc.render() + + +@with_database_session +@mcp_permission("view") +def get_test_run(job_execution_id: str) -> str: + """Get a single test run with status, timing, result counts, and testing score. Returns the + run regardless of state — including queued and in-progress runs without complete results yet. + + Args: + job_execution_id: UUID of a test run, e.g. from `list_test_runs`. + """ + parse_uuid(job_execution_id, "job_execution_id") + perms = get_project_permissions() + + summaries, _ = TestRun.select_summary(job_execution_id=job_execution_id, page_size=1) + summary = summaries[0] if summaries else None + if summary is None or summary.project_code not in perms.allowed_codes: + raise MCPResourceNotAccessible("Test run", job_execution_id) doc = MdDoc() + suite_label = summary.test_suite or "—" + doc.heading(1, f"Test run: {suite_label}") + doc.field("Job ID", summary.job_execution_id, code=True) + doc.field("Test suite", suite_label) + if summary.table_groups_name: + doc.field("Table group", summary.table_groups_name) + doc.field("Project", summary.project_code) + doc.field("Status", summary.status_label) + doc.field("Submitted", summary.created_at) + doc.field("Started", summary.started_at or "—") + doc.field("Ended", summary.completed_at or "In progress") + duration = format_run_duration(summary.started_at, summary.completed_at) + if duration: + doc.field("Duration", duration) + + has_results = summary.test_ct or summary.passed_ct or summary.failed_ct or summary.warning_ct or summary.error_ct + if has_results: + passed = summary.passed_ct or 0 + failed = summary.failed_ct or 0 + warning = summary.warning_ct or 0 + errors = summary.error_ct or 0 + doc.field( + "Results", + f"{summary.test_ct or 0} tests — {passed} passed, {failed} failed, {warning} warnings, {errors} errors", + ) + if summary.dismissed_ct: + doc.field("Dismissed", summary.dismissed_ct) + if summary.dq_score_testing is not None: + doc.field("Testing Score", f"{summary.dq_score_testing:.1f}") + + if summary.error_message: + doc.heading(2, "Error") + doc.text(summary.error_message) + + return doc.render() + + +def _scope_descriptor( + project_code: str | None, + test_suite: str | None, + table_group_id: str | None, + status: str | None, +) -> str: + parts: list[str] = [] + if project_code: + parts.append(f"project `{project_code}`") if test_suite: - doc.heading(1, f"Recent Test Runs for `{project_code}` / `{test_suite}`") - else: - doc.heading(1, f"Recent Test Runs for `{project_code}`") - doc.text(f"Showing {len(runs)} run(s) ({limit} per suite).") + parts.append(f"suite `{test_suite}`") + if table_group_id: + parts.append(f"table group `{table_group_id}`") + if status: + parts.append(f"status `{status}`") + return f" — {', '.join(parts)}" if parts else "" - current_suite = None - for run in runs: - if run.test_suite != current_suite: - current_suite = run.test_suite - doc.heading(2, current_suite) - passed = run.passed_ct or 0 - failed = run.failed_ct or 0 - warning = run.warning_ct or 0 - errors = run.error_ct or 0 +def _next_test_run( + project_code: str | None, + test_suite_id: str | None, + table_group_id: str | None, +) -> datetime | None: + """Compute the next scheduled test run when scoped to a single suite or table group.""" + if not project_code: + return None + if test_suite_id: + return next_scheduled_run(RUN_TESTS_JOB_KEY, {"test_suite_id": test_suite_id}, project_code) + if table_group_id: + suite_ids = [ + str(s.id) + for s in TestSuite.select_minimal_where( + TestSuite.project_code == project_code, + TestSuite.table_groups_id == table_group_id, + TestSuite.is_monitor.isnot(True), + ) + ] + candidates = [ + next_scheduled_run(RUN_TESTS_JOB_KEY, {"test_suite_id": sid}, project_code) + for sid in suite_ids + ] + candidates = [c for c in candidates if c is not None] + return min(candidates) if candidates else None + return None - doc.heading(3, f"{run.created_at} — {run.status_label}") - doc.field("Test Run", run.job_execution_id, code=True) - doc.field("Started", run.created_at) - doc.field("Ended", run.completed_at or "In progress") - doc.field("Results", f"{run.test_ct or 0} tests — {passed} passed, {failed} failed, {warning} warnings, {errors} errors") - if run.dismissed_ct: - doc.field("Dismissed", run.dismissed_ct) +def _select_pending_test_jes( + *, + project_code: str, + test_suite_id: str | None, + table_group_id: str | None, + statuses, +) -> list[JobExecution]: + """Find queued/in-flight test-run JEs for a given suite or table group scope. For a + table-group scope, expands to the non-monitor suites in the group so monitor runs stay + excluded. + """ + if test_suite_id: + suite_ids: str | list[str] = test_suite_id + elif table_group_id: + suite_ids = [ + str(s.id) + for s in TestSuite.select_minimal_where( + TestSuite.project_code == project_code, + TestSuite.table_groups_id == table_group_id, + TestSuite.is_monitor.isnot(True), + ) + ] + if not suite_ids: + return [] + else: + return [] + return JobExecution.select_active_by_kwargs( + project_code=project_code, + job_key=RUN_TESTS_JOB_KEY, + kwargs_match={"test_suite_id": suite_ids}, + statuses=statuses, + ) - if run.dq_score_testing is not None: - doc.field("Testing Score", f"{run.dq_score_testing:.1f}") - doc.text("Use `list_test_results(job_execution_id='...')` for detailed results of a specific run.") +def _render_pending_je(doc: MdDoc, je: JobExecution, label: str) -> None: + status_label = TestRunSummary.STATUS_LABEL.get(je.status, je.status) + doc.heading(3, f"{label} — {status_label}") + doc.field("Job ID", je.id, code=True) + doc.field("Submitted", je.created_at) + doc.field("Started", je.started_at or "—") + doc.field("Ended", je.completed_at or "In progress") - return doc.render() + +def _render_test_run_section(doc: MdDoc, run: TestRunSummary) -> None: + title = run.test_suite or run.project_code + doc.heading(2, f"{title} — {run.status_label}") + doc.field("Job ID", run.job_execution_id, code=True) + if run.test_suite: + doc.field("Test suite", run.test_suite) + if run.table_groups_name: + doc.field("Table group", run.table_groups_name) + doc.field("Submitted", run.created_at) + doc.field("Started", run.started_at or "—") + doc.field("Ended", run.completed_at or "In progress") + duration = format_run_duration(run.started_at, run.completed_at) + if duration: + doc.field("Duration", duration) + + passed = run.passed_ct or 0 + failed = run.failed_ct or 0 + warning = run.warning_ct or 0 + errors = run.error_ct or 0 + if run.test_ct or passed or failed or warning or errors: + doc.field( + "Results", + f"{run.test_ct or 0} tests — {passed} passed, {failed} failed, {warning} warnings, {errors} errors", + ) + + if run.dismissed_ct: + doc.field("Dismissed", run.dismissed_ct) + if run.dq_score_testing is not None: + doc.field("Testing Score", f"{run.dq_score_testing:.1f}") diff --git a/tests/unit/mcp/test_tools_execution.py b/tests/unit/mcp/test_tools_execution.py index 79f7e99b..8b5ece21 100644 --- a/tests/unit/mcp/test_tools_execution.py +++ b/tests/unit/mcp/test_tools_execution.py @@ -65,7 +65,7 @@ def test_run_tests_submits_job(mock_suite_cls, mock_job_exec, db_session_mock): assert "Test run submitted for `Quality Suite`" in result assert str(submitted.id) in result assert "Pending" in result - assert "get_recent_test_runs" in result + assert "list_test_runs" in result def test_run_tests_invalid_uuid(db_session_mock): @@ -259,7 +259,7 @@ def fake_request_cancel(): assert "Test run cancellation requested" in result assert str(job_id) in result assert "cancel_requested" in result - assert "get_recent_test_runs" in result + assert "list_test_runs" in result def test_cancel_test_run_filters_by_job_key(db_session_mock): diff --git a/tests/unit/mcp/test_tools_profiling.py b/tests/unit/mcp/test_tools_profiling.py index e9773075..9415933b 100644 --- a/tests/unit/mcp/test_tools_profiling.py +++ b/tests/unit/mcp/test_tools_profiling.py @@ -479,3 +479,198 @@ def test_list_profiling_summaries_inaccessible_tg(mock_tg_cls, db_session_mock): from testgen.mcp.tools.profiling import list_profiling_summaries with pytest.raises(MCPResourceNotAccessible, match="Table group .* not found or not accessible"): list_profiling_summaries(table_group_id=str(uuid4())) + + +# ---------------------------------------------------------------------- +# list_profiling_runs +# ---------------------------------------------------------------------- + +from datetime import UTC, datetime + +from testgen.common.models.job_execution import JobStatus + +_RUN_CREATED = datetime(2026, 4, 1, 10, 0, 0, tzinfo=UTC) +_RUN_STARTED = datetime(2026, 4, 1, 10, 0, 5, tzinfo=UTC) +_RUN_COMPLETED = datetime(2026, 4, 1, 10, 1, 30, tzinfo=UTC) + + +def _mock_profiling_run(**overrides): + defaults = { + "job_execution_id": uuid4(), + "profiling_run_id": uuid4(), + "project_code": "demo", + "status": JobStatus.COMPLETED, + "status_label": "Completed", + "created_at": _RUN_CREATED, + "started_at": _RUN_STARTED, + "completed_at": _RUN_COMPLETED, + "error_message": None, + "table_groups_name": "demo-tg", + "table_group_schema": "demo", + "table_ct": 5, "column_ct": 30, "record_ct": 1000, + "anomaly_ct": 4, + "anomalies_definite_ct": 1, "anomalies_likely_ct": 1, + "anomalies_possible_ct": 2, "anomalies_dismissed_ct": 0, + "dq_score_profiling": 95.5, + } + defaults.update(overrides) + return MagicMock(**defaults) + + +@patch("testgen.mcp.tools.profiling.JobExecution") +@patch("testgen.mcp.tools.profiling.next_scheduled_run", return_value=None) +@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_profiling_runs_default(mock_tg_cls, mock_run_cls, mock_next, mock_je, db_session_mock): + mock_je.select_active_by_kwargs.return_value = [] + tg = _mock_table_group() + tg.table_groups_name = "demo-tg" + mock_tg_cls.get.return_value = tg + mock_run_cls.select_summary.return_value = ([_mock_profiling_run()], 1) + + from testgen.mcp.tools.profiling import list_profiling_runs + result = list_profiling_runs(table_group_id=str(uuid4())) + + assert "Profiling runs for `demo-tg`" in result + assert "Completed" in result + call_kwargs = mock_run_cls.select_summary.call_args.kwargs + assert call_kwargs["statuses"] is None + + +@patch("testgen.mcp.tools.profiling.JobExecution") +@patch("testgen.mcp.tools.profiling.next_scheduled_run", return_value=None) +@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_profiling_runs_status_filter(mock_tg_cls, mock_run_cls, mock_next, mock_je, db_session_mock): + mock_je.select_active_by_kwargs.return_value = [] + mock_tg_cls.get.return_value = _mock_table_group() + mock_run_cls.select_summary.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_profiling_runs + list_profiling_runs(table_group_id=str(uuid4()), status="Pending") + + call_kwargs = mock_run_cls.select_summary.call_args.kwargs + assert call_kwargs["statuses"] == [JobStatus.PENDING, JobStatus.CLAIMED] + + +@patch("testgen.mcp.tools.profiling.JobExecution") +@patch("testgen.mcp.tools.profiling.next_scheduled_run", return_value=_RUN_STARTED) +@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_profiling_runs_shows_next_scheduled(mock_tg_cls, mock_run_cls, mock_next, mock_je, db_session_mock): + mock_je.select_active_by_kwargs.return_value = [] + mock_tg_cls.get.return_value = _mock_table_group() + mock_run_cls.select_summary.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_profiling_runs + result = list_profiling_runs(table_group_id=str(uuid4())) + + assert "Next scheduled run" in result + + +@patch("testgen.mcp.tools.profiling.next_scheduled_run", return_value=None) +@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_profiling_runs_invalid_status(mock_tg_cls, mock_run_cls, mock_next, db_session_mock): + mock_tg_cls.get.return_value = _mock_table_group() + + from testgen.mcp.tools.profiling import list_profiling_runs + with pytest.raises(MCPUserError, match="Invalid status"): + list_profiling_runs(table_group_id=str(uuid4()), status="Bogus") + + +# ---------------------------------------------------------------------- +# get_profiling_run +# ---------------------------------------------------------------------- + + +@patch("testgen.mcp.tools.profiling.ProfilingRun") +def test_get_profiling_run_returns_detail(mock_run_cls, db_session_mock): + summary = _mock_profiling_run() + mock_run_cls.select_summary.return_value = ([summary], 1) + mock_run = MagicMock(project_code="demo") + mock_run_cls.get_by_id_or_job.return_value = mock_run + mock_run_cls.select_table_breakdown.return_value = [ + MagicMock(schema_name="demo", table_name="orders", record_ct=1000, column_ct=5, anomaly_ct=2), + ] + + with patch("testgen.mcp.permissions._compute_project_permissions") as mock_compute: + mock_compute.return_value = ProjectPermissions( + memberships={"demo": "role_a"}, permission="catalog", username="test_user", + ) + with patch("testgen.mcp.permissions.PluginHook") as mock_hook: + mock_hook.instance().rbac.get_roles_with_permission.return_value = ["role_a"] + + from testgen.mcp.tools.profiling import get_profiling_run + result = get_profiling_run(str(summary.job_execution_id)) + + assert "Profiling run: demo-tg" in result + assert "Completed" in result + assert "Per-table breakdown" in result + assert "orders" in result + + +@patch("testgen.mcp.tools.profiling.ProfilingRun") +def test_get_profiling_run_pending_no_breakdown(mock_run_cls, db_session_mock): + summary = _mock_profiling_run( + status=JobStatus.PENDING, status_label="Pending", + profiling_run_id=None, started_at=None, completed_at=None, + table_ct=None, column_ct=None, record_ct=None, anomaly_ct=None, + anomalies_definite_ct=None, anomalies_likely_ct=None, + anomalies_possible_ct=None, dq_score_profiling=None, + ) + mock_run_cls.select_summary.return_value = ([summary], 1) + mock_run_cls.get_by_id_or_job.return_value = MagicMock(project_code="demo") + + with patch("testgen.mcp.permissions._compute_project_permissions") as mock_compute: + mock_compute.return_value = ProjectPermissions( + memberships={"demo": "role_a"}, permission="catalog", username="test_user", + ) + with patch("testgen.mcp.permissions.PluginHook") as mock_hook: + mock_hook.instance().rbac.get_roles_with_permission.return_value = ["role_a"] + + from testgen.mcp.tools.profiling import get_profiling_run + result = get_profiling_run(str(summary.job_execution_id)) + + assert "Pending" in result + assert "In progress" in result + assert "Per-table breakdown" not in result + + +@patch("testgen.mcp.tools.profiling.ProfilingRun") +def test_get_profiling_run_not_found(mock_run_cls, db_session_mock): + mock_run_cls.select_summary.return_value = ([], 0) + + with patch("testgen.mcp.permissions._compute_project_permissions") as mock_compute: + mock_compute.return_value = ProjectPermissions( + memberships={"demo": "role_a"}, permission="catalog", username="test_user", + ) + with patch("testgen.mcp.permissions.PluginHook") as mock_hook: + mock_hook.instance().rbac.get_roles_with_permission.return_value = ["role_a"] + + from testgen.mcp.tools.profiling import get_profiling_run + with pytest.raises(MCPResourceNotAccessible): + get_profiling_run(str(uuid4())) + + +@patch("testgen.mcp.tools.profiling.ProfilingRun") +def test_get_profiling_run_inaccessible_project(mock_run_cls, db_session_mock): + summary = _mock_profiling_run(project_code="secret") + mock_run_cls.select_summary.return_value = ([summary], 1) + + with patch("testgen.mcp.permissions._compute_project_permissions") as mock_compute: + mock_compute.return_value = ProjectPermissions( + memberships={"demo": "role_a"}, permission="catalog", username="test_user", + ) + with patch("testgen.mcp.permissions.PluginHook") as mock_hook: + mock_hook.instance().rbac.get_roles_with_permission.return_value = ["role_a"] + + from testgen.mcp.tools.profiling import get_profiling_run + with pytest.raises(MCPResourceNotAccessible): + get_profiling_run(str(summary.job_execution_id)) + + +def test_get_profiling_run_invalid_uuid(db_session_mock): + from testgen.mcp.tools.profiling import get_profiling_run + with pytest.raises(MCPUserError, match="not a valid UUID"): + get_profiling_run("not-a-uuid") diff --git a/tests/unit/mcp/test_tools_test_results.py b/tests/unit/mcp/test_tools_test_results.py index cadcb86c..f89c9d1f 100644 --- a/tests/unit/mcp/test_tools_test_results.py +++ b/tests/unit/mcp/test_tools_test_results.py @@ -487,7 +487,7 @@ def test_get_failure_summary_passes_project_codes( @patch("testgen.mcp.tools.test_results.TestType") @patch("testgen.mcp.tools.test_results.TestResult") -def test_get_test_result_history_basic(mock_result, mock_tt_cls, db_session_mock): +def test_list_test_result_history_basic(mock_result, mock_tt_cls, db_session_mock): def_id = str(uuid4()) r1 = MagicMock() r1.test_type = "Unique_Pct" @@ -512,9 +512,9 @@ def test_get_test_result_history_basic(mock_result, mock_tt_cls, db_session_mock tt.test_name_short = "Unique Percent" mock_tt_cls.select_where.return_value = [tt] - from testgen.mcp.tools.test_results import get_test_result_history + from testgen.mcp.tools.test_results import list_test_result_history - result = get_test_result_history(def_id) + result = list_test_result_history(def_id) assert "Unique Percent" in result assert "Unique_Pct" not in result @@ -526,26 +526,26 @@ def test_get_test_result_history_basic(mock_result, mock_tt_cls, db_session_mock @patch("testgen.mcp.tools.test_results.TestResult") -def test_get_test_result_history_empty(mock_result, db_session_mock): +def test_list_test_result_history_empty(mock_result, db_session_mock): mock_result.select_history.return_value = [] - from testgen.mcp.tools.test_results import get_test_result_history + from testgen.mcp.tools.test_results import list_test_result_history - result = get_test_result_history(str(uuid4())) + result = list_test_result_history(str(uuid4())) assert "No historical results" in result -def test_get_test_result_history_invalid_uuid(db_session_mock): - from testgen.mcp.tools.test_results import get_test_result_history +def test_list_test_result_history_invalid_uuid(db_session_mock): + from testgen.mcp.tools.test_results import list_test_result_history with pytest.raises(MCPUserError, match="not a valid UUID"): - get_test_result_history("bad-uuid") + list_test_result_history("bad-uuid") @patch("testgen.mcp.tools.test_results.TestResult") @patch("testgen.mcp.permissions._compute_project_permissions") -def test_get_test_result_history_passes_project_codes( +def test_list_test_result_history_passes_project_codes( mock_compute, mock_result, db_session_mock, ): mock_compute.return_value = ProjectPermissions( @@ -555,9 +555,9 @@ def test_get_test_result_history_passes_project_codes( ) mock_result.select_history.return_value = [] - from testgen.mcp.tools.test_results import get_test_result_history + from testgen.mcp.tools.test_results import list_test_result_history - get_test_result_history(str(uuid4())) + list_test_result_history(str(uuid4())) call_kwargs = mock_result.select_history.call_args.kwargs assert call_kwargs["project_codes"] == ["proj_a"] diff --git a/tests/unit/mcp/test_tools_test_runs.py b/tests/unit/mcp/test_tools_test_runs.py index c914dd25..3728ae27 100644 --- a/tests/unit/mcp/test_tools_test_runs.py +++ b/tests/unit/mcp/test_tools_test_runs.py @@ -1,188 +1,305 @@ +from datetime import UTC, datetime from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest -from testgen.mcp.exceptions import MCPPermissionDenied +from testgen.common.models.job_execution import JobStatus +from testgen.mcp.exceptions import MCPPermissionDenied, MCPResourceNotAccessible, MCPUserError from testgen.mcp.permissions import ProjectPermissions +_CREATED = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC) +_STARTED = datetime(2024, 1, 15, 10, 0, 0, tzinfo=UTC) +_COMPLETED = datetime(2024, 1, 15, 10, 5, 0, tzinfo=UTC) + def _make_run_summary(**overrides): defaults = { "test_run_id": uuid4(), "job_execution_id": uuid4(), - "test_suite": "Quality Suite", "project_name": "Demo", - "table_groups_name": "core_tables", "status": "completed", + "test_suite": "Quality Suite", "project_name": "Demo", "project_code": "demo", + "table_groups_name": "core_tables", "status": JobStatus.COMPLETED, "status_label": "Completed", - "created_at": "2024-01-15T10:00:00", - "started_at": "2024-01-15T10:00:00", "completed_at": "2024-01-15T10:05:00", + "created_at": _CREATED, "started_at": _STARTED, "completed_at": _COMPLETED, "test_ct": 50, "passed_ct": 45, "failed_ct": 3, "warning_ct": 2, "error_ct": 0, "log_ct": 0, "dismissed_ct": 0, "dq_score_testing": 92.5, + "error_message": None, } defaults.update(overrides) return MagicMock(**defaults) +@patch("testgen.mcp.tools.test_runs.next_scheduled_run", return_value=None) @patch("testgen.mcp.tools.test_runs.TestRun") @patch("testgen.mcp.tools.test_runs.TestSuite") -def test_get_recent_test_runs_default_limit(mock_suite, mock_run, db_session_mock): - """Default limit=1 returns one run per suite.""" - runs = [_make_run_summary(test_run_id=uuid4()) for _ in range(7)] +def test_list_test_runs_default(mock_suite, mock_run, mock_next, db_session_mock): + runs = [_make_run_summary() for _ in range(3)] mock_run.select_summary.return_value = (runs, len(runs)) - from testgen.mcp.tools.test_runs import get_recent_test_runs + from testgen.mcp.tools.test_runs import list_test_runs - result = get_recent_test_runs("demo") + result = list_test_runs(project_code="demo") - # All 7 runs have test_suite="Quality Suite", so only 1 should appear - assert "1 run(s)" in result + mock_run.select_summary.assert_called_once_with( + project_code="demo", + table_group_id=None, + test_suite_id=None, + statuses=None, + page=1, + page_size=10, + ) + assert "Test runs" in result + assert "demo" in result assert "Quality Suite" in result assert "92.5" in result - mock_run.select_summary.assert_called_once_with(project_code="demo", test_suite_id=None, page_size=1000) - - -@patch("testgen.mcp.tools.test_runs.TestRun") -@patch("testgen.mcp.tools.test_runs.TestSuite") -def test_get_recent_test_runs_custom_limit(mock_suite, mock_run, db_session_mock): - """Custom limit returns up to N runs per suite.""" - runs = [_make_run_summary() for _ in range(3)] - mock_run.select_summary.return_value = (runs, len(runs)) - - from testgen.mcp.tools.test_runs import get_recent_test_runs - - result = get_recent_test_runs("demo", limit=10) - - assert "3 run(s)" in result +@patch("testgen.mcp.tools.test_runs.next_scheduled_run", return_value=None) @patch("testgen.mcp.tools.test_runs.TestRun") @patch("testgen.mcp.tools.test_runs.TestSuite") -def test_get_recent_test_runs_per_suite_grouping(mock_suite, mock_run, db_session_mock): - """With multiple suites, returns limit runs per suite.""" - runs = [ - _make_run_summary(test_suite="Suite A", test_run_id=uuid4()), - _make_run_summary(test_suite="Suite A", test_run_id=uuid4()), - _make_run_summary(test_suite="Suite B", test_run_id=uuid4()), - _make_run_summary(test_suite="Suite B", test_run_id=uuid4()), - ] - mock_run.select_summary.return_value = (runs, len(runs)) +def test_list_test_runs_with_status_filter(mock_suite, mock_run, mock_next, db_session_mock): + mock_run.select_summary.return_value = ([], 0) - from testgen.mcp.tools.test_runs import get_recent_test_runs + from testgen.mcp.tools.test_runs import list_test_runs - result = get_recent_test_runs("demo") + list_test_runs(project_code="demo", status="Pending") - # limit=1 (default), so 1 per suite = 2 total - assert "2 run(s)" in result - assert "Suite A" in result - assert "Suite B" in result + call_kwargs = mock_run.select_summary.call_args.kwargs + assert call_kwargs["statuses"] == [JobStatus.PENDING, JobStatus.CLAIMED] +@patch("testgen.mcp.tools.test_runs.JobExecution") +@patch("testgen.mcp.tools.test_runs.next_scheduled_run", return_value=None) @patch("testgen.mcp.tools.test_runs.TestRun") @patch("testgen.mcp.tools.test_runs.TestSuite") -def test_get_recent_test_runs_with_suite_name(mock_suite, mock_run, db_session_mock): +def test_list_test_runs_with_suite_name(mock_suite, mock_run, mock_next, mock_je, db_session_mock): + mock_je.select_active_by_kwargs.return_value = [] suite_id = uuid4() - suite_minimal = MagicMock() - suite_minimal.id = suite_id + suite_minimal = MagicMock(id=suite_id) mock_suite.select_minimal_where.return_value = [suite_minimal] mock_run.select_summary.return_value = ([_make_run_summary(test_suite="My Suite")], 1) - from testgen.mcp.tools.test_runs import get_recent_test_runs + from testgen.mcp.tools.test_runs import list_test_runs - result = get_recent_test_runs("demo", test_suite="My Suite") + result = list_test_runs(project_code="demo", test_suite="My Suite") - mock_run.select_summary.assert_called_once_with(project_code="demo", test_suite_id=str(suite_id), page_size=1000) + call_kwargs = mock_run.select_summary.call_args.kwargs + assert call_kwargs["test_suite_id"] == str(suite_id) assert "My Suite" in result +@patch("testgen.mcp.tools.test_runs.next_scheduled_run", return_value=None) @patch("testgen.mcp.tools.test_runs.TestRun") @patch("testgen.mcp.tools.test_runs.TestSuite") -def test_get_recent_test_runs_suite_not_found(mock_suite, mock_run, db_session_mock): +def test_list_test_runs_suite_not_found(mock_suite, mock_run, mock_next, db_session_mock): mock_suite.select_minimal_where.return_value = [] - from testgen.mcp.tools.test_runs import get_recent_test_runs + from testgen.mcp.tools.test_runs import list_test_runs - result = get_recent_test_runs("demo", test_suite="Nonexistent") - - assert "not found" in result + with pytest.raises(MCPResourceNotAccessible): + list_test_runs(project_code="demo", test_suite="Nonexistent") mock_run.select_summary.assert_not_called() +@patch("testgen.mcp.tools.test_runs.next_scheduled_run", return_value=None) @patch("testgen.mcp.tools.test_runs.TestRun") @patch("testgen.mcp.tools.test_runs.TestSuite") -def test_get_recent_test_runs_no_runs(mock_suite, mock_run, db_session_mock): +def test_list_test_runs_empty(mock_suite, mock_run, mock_next, db_session_mock): mock_run.select_summary.return_value = ([], 0) - from testgen.mcp.tools.test_runs import get_recent_test_runs + from testgen.mcp.tools.test_runs import list_test_runs - result = get_recent_test_runs("demo") + result = list_test_runs(project_code="demo") - assert "No completed test runs" in result + assert "No test runs" in result +@patch("testgen.mcp.tools.test_runs.next_scheduled_run", return_value=None) @patch("testgen.mcp.tools.test_runs.TestRun") @patch("testgen.mcp.tools.test_runs.TestSuite") -def test_get_recent_test_runs_shows_failure_counts(mock_suite, mock_run, db_session_mock): - mock_run.select_summary.return_value = ([_make_run_summary(failed_ct=5, warning_ct=2)], 1) +def test_list_test_runs_includes_pending_run(mock_suite, mock_run, mock_next, db_session_mock): + pending = _make_run_summary( + status=JobStatus.PENDING, status_label="Pending", + started_at=None, completed_at=None, + test_ct=None, passed_ct=None, failed_ct=None, warning_ct=None, error_ct=None, + log_ct=None, dismissed_ct=None, dq_score_testing=None, + ) + mock_run.select_summary.return_value = ([pending], 1) - from testgen.mcp.tools.test_runs import get_recent_test_runs + from testgen.mcp.tools.test_runs import list_test_runs - result = get_recent_test_runs("demo") + result = list_test_runs(project_code="demo") - assert "5 failed" in result - assert "2 warnings" in result + assert "Pending" in result + assert "In progress" in result +@patch("testgen.mcp.tools.test_runs.JobExecution") +@patch("testgen.mcp.tools.test_runs.next_scheduled_run", return_value="2026-06-01T02:00:00") @patch("testgen.mcp.tools.test_runs.TestRun") @patch("testgen.mcp.tools.test_runs.TestSuite") -def test_get_recent_test_runs_outputs_job_execution_id(mock_suite, mock_run, db_session_mock): - """Output should contain job_execution_id, not test_run_id.""" - job_exec_id = uuid4() - run = _make_run_summary(job_execution_id=job_exec_id) - mock_run.select_summary.return_value = ([run], 1) +def test_list_test_runs_shows_next_scheduled(mock_suite, mock_run, mock_next, mock_je, db_session_mock): + mock_je.select_active_by_kwargs.return_value = [] + suite_id = uuid4() + mock_suite.select_minimal_where.return_value = [MagicMock(id=suite_id)] + mock_run.select_summary.return_value = ([], 0) + + from testgen.mcp.tools.test_runs import list_test_runs + + result = list_test_runs(project_code="demo", test_suite="Quality") + + assert "Next scheduled run" in result + + +@patch("testgen.mcp.tools.test_runs.JobExecution") +@patch("testgen.mcp.tools.test_runs.next_scheduled_run", return_value=None) +@patch("testgen.mcp.tools.test_runs.TestRun") +@patch("testgen.mcp.tools.test_runs.TestSuite") +def test_list_test_runs_renders_pending_section( + mock_suite, mock_run, mock_next, mock_je, db_session_mock, +): + """When scoped by suite, pending JEs are surfaced in a separate section.""" + suite_id = uuid4() + mock_suite.select_minimal_where.return_value = [MagicMock(id=suite_id)] + mock_run.select_summary.return_value = ([], 0) + pending_je = MagicMock( + id=uuid4(), status=JobStatus.PENDING, + created_at=_CREATED, started_at=None, completed_at=None, + ) + mock_je.select_active_by_kwargs.return_value = [pending_je] + + from testgen.mcp.tools.test_runs import list_test_runs + + result = list_test_runs(project_code="demo", test_suite="Quality") - from testgen.mcp.tools.test_runs import get_recent_test_runs + assert "Pending (1)" in result + assert "In progress" in result + mock_je.select_active_by_kwargs.assert_called_once() - result = get_recent_test_runs("demo") - assert str(job_exec_id) in result - assert "job_execution_id" in result +def test_list_test_runs_invalid_status(db_session_mock): + from testgen.mcp.tools.test_runs import list_test_runs + with pytest.raises(MCPUserError, match="Invalid status"): + list_test_runs(project_code="demo", status="Bogus") -def test_get_recent_test_runs_empty_project_code(db_session_mock): - from testgen.mcp.tools.test_runs import get_recent_test_runs - result = get_recent_test_runs("") +def test_list_test_runs_requires_project_or_table_group(db_session_mock): + from testgen.mcp.tools.test_runs import list_test_runs - assert "Missing required parameter" in result - assert "project_code" in result + with pytest.raises(MCPUserError, match="Provide either"): + list_test_runs() @patch("testgen.mcp.permissions._compute_project_permissions") -def test_get_recent_test_runs_raises_not_found_for_inaccessible_project( - mock_compute, db_session_mock, -): +def test_list_test_runs_raises_not_found_for_inaccessible_project(mock_compute, db_session_mock): mock_compute.return_value = ProjectPermissions( memberships={"other_project": "role_a"}, permission="view", username="test_user", ) - from testgen.mcp.tools.test_runs import get_recent_test_runs + from testgen.mcp.tools.test_runs import list_test_runs - with pytest.raises(MCPPermissionDenied, match="No completed test runs found in project `secret_project`"): - get_recent_test_runs("secret_project") + with pytest.raises(MCPPermissionDenied): + list_test_runs(project_code="secret_project") -@patch("testgen.mcp.permissions._compute_project_permissions") -def test_get_recent_test_runs_raises_denial_for_insufficient_permission( - mock_compute, db_session_mock, -): - mock_compute.return_value = ProjectPermissions( - memberships={"other_project": "role_a", "secret_project": "role_c"}, - permission="view", - username="test_user", +# ---------------------------------------------------------------------- +# get_test_run +# ---------------------------------------------------------------------- + + +@patch("testgen.mcp.tools.test_runs.TestRun") +def test_get_test_run_returns_detail(mock_run, db_session_mock): + summary = _make_run_summary(project_code="demo") + mock_run.select_summary.return_value = ([summary], 1) + + with patch("testgen.mcp.permissions._compute_project_permissions") as mock_compute: + mock_compute.return_value = ProjectPermissions( + memberships={"demo": "role_a"}, + permission="view", + username="test_user", + ) + with patch( + "testgen.mcp.permissions.PluginHook" + ) as mock_hook: + mock_hook.instance().rbac.get_roles_with_permission.return_value = ["role_a"] + from testgen.mcp.tools.test_runs import get_test_run + + result = get_test_run(str(summary.job_execution_id)) + + assert "Quality Suite" in result + assert "Completed" in result + assert "92.5" in result + + +@patch("testgen.mcp.tools.test_runs.TestRun") +def test_get_test_run_pending_no_results(mock_run, db_session_mock): + summary = _make_run_summary( + project_code="demo", + status=JobStatus.PENDING, status_label="Pending", + started_at=None, completed_at=None, + test_ct=None, passed_ct=None, failed_ct=None, warning_ct=None, error_ct=None, + log_ct=None, dismissed_ct=None, dq_score_testing=None, ) + mock_run.select_summary.return_value = ([summary], 1) + + with patch("testgen.mcp.permissions._compute_project_permissions") as mock_compute: + mock_compute.return_value = ProjectPermissions( + memberships={"demo": "role_a"}, + permission="view", + username="test_user", + ) + with patch("testgen.mcp.permissions.PluginHook") as mock_hook: + mock_hook.instance().rbac.get_roles_with_permission.return_value = ["role_a"] + from testgen.mcp.tools.test_runs import get_test_run + + result = get_test_run(str(summary.job_execution_id)) + + assert "Pending" in result + assert "In progress" in result + assert "Results" not in result + + +@patch("testgen.mcp.tools.test_runs.TestRun") +def test_get_test_run_not_found(mock_run, db_session_mock): + mock_run.select_summary.return_value = ([], 0) + + with patch("testgen.mcp.permissions._compute_project_permissions") as mock_compute: + mock_compute.return_value = ProjectPermissions( + memberships={"demo": "role_a"}, + permission="view", + username="test_user", + ) + with patch("testgen.mcp.permissions.PluginHook") as mock_hook: + mock_hook.instance().rbac.get_roles_with_permission.return_value = ["role_a"] + from testgen.mcp.tools.test_runs import get_test_run + + with pytest.raises(MCPResourceNotAccessible): + get_test_run(str(uuid4())) + + +@patch("testgen.mcp.tools.test_runs.TestRun") +def test_get_test_run_inaccessible_project(mock_run, db_session_mock): + summary = _make_run_summary(project_code="secret") + mock_run.select_summary.return_value = ([summary], 1) + + with patch("testgen.mcp.permissions._compute_project_permissions") as mock_compute: + mock_compute.return_value = ProjectPermissions( + memberships={"demo": "role_a"}, + permission="view", + username="test_user", + ) + with patch("testgen.mcp.permissions.PluginHook") as mock_hook: + mock_hook.instance().rbac.get_roles_with_permission.return_value = ["role_a"] + from testgen.mcp.tools.test_runs import get_test_run + + with pytest.raises(MCPResourceNotAccessible): + get_test_run(str(summary.job_execution_id)) + - from testgen.mcp.tools.test_runs import get_recent_test_runs +def test_get_test_run_invalid_uuid(db_session_mock): + from testgen.mcp.tools.test_runs import get_test_run - with pytest.raises(MCPPermissionDenied, match="necessary permission"): - get_recent_test_runs("secret_project") + with pytest.raises(MCPUserError, match="not a valid UUID"): + get_test_run("not-a-uuid") From 6b2c390394e111ea62f7e882e8c75adc0616a1c3 Mon Sep 17 00:00:00 2001 From: Ricardo Boni Date: Mon, 11 May 2026 14:03:25 -0400 Subject: [PATCH 05/58] feat(mcp): add test definition CRUD tools (TG-1054) Add create_test, update_test, validate_custom_test, and bulk_update_tests MCP tools, gated on the edit permission. Consolidate validation onto TestDefinition with editable_fields() and validate() methods, enforcing a whitelist at the MCP boundary so extra_params cannot override identity or internal columns. Co-Authored-By: Claude Opus 4.7 (1M context) --- testgen/common/models/test_definition.py | 87 +++ testgen/mcp/server.py | 15 +- testgen/mcp/tools/common.py | 28 +- testgen/mcp/tools/test_definitions.py | 388 +++++++++++- .../common/models/test_test_definition.py | 207 ++++++ tests/unit/mcp/test_tools_test_definitions.py | 587 ++++++++++++++++++ 6 files changed, 1277 insertions(+), 35 deletions(-) create mode 100644 tests/unit/common/models/test_test_definition.py diff --git a/testgen/common/models/test_definition.py b/testgen/common/models/test_definition.py index 8740203b..50cd78ef 100644 --- a/testgen/common/models/test_definition.py +++ b/testgen/common/models/test_definition.py @@ -1,6 +1,7 @@ from collections.abc import Iterable from dataclasses import dataclass from datetime import datetime +from enum import StrEnum from itertools import zip_longest from typing import ClassVar, Literal from uuid import UUID, uuid4 @@ -35,6 +36,24 @@ TestRunStatus = Literal["Running", "Complete", "Error", "Cancelled"] +class Severity(StrEnum): + FAIL = "Fail" + WARNING = "Warning" + + +class InvalidTestDefinitionFields(ValueError): + """Aggregated field-level validation errors. ``errors``: ``dict[field_name, reason]``.""" + + def __init__(self, errors: dict[str, str]) -> None: + self.errors = errors + super().__init__("; ".join(f"{k}: {v}" for k, v in errors.items())) + + +def _is_blank(value: object) -> bool: + # NullIfEmptyString columns turn ``""`` into NULL on write — treat both as cleared. + return value is None or value == "" + + class ParamFieldsMixin: """Parsed access to default_parm_columns/prompts/help metadata. @@ -204,6 +223,28 @@ def select_summary_where(cls, *clauses) -> Iterable[TestTypeSummary]: return [TestTypeSummary(**row) for row in results] +def _required_fields_for(test_type: TestType) -> set[str]: + """Fields that must be present and non-empty for the given test type. + + - Column-scoped tests implicitly require ``column_name``. + - Test types with ``custom_query`` in ``param_columns`` require ``custom_query``. + - ``default_parm_required`` is a CSV of ``Y``/``N`` aligned with ``default_parm_columns``; + positions marked ``Y`` are required. + """ + required: set[str] = set() + if test_type.test_scope == "column": + required.add("column_name") + if "custom_query" in test_type.param_columns: + required.add("custom_query") + if test_type.default_parm_required and test_type.default_parm_columns: + flags = [v.strip().upper() for v in test_type.default_parm_required.split(",")] + columns = [c.strip() for c in test_type.default_parm_columns.split(",")] + for col, flag in zip(columns, flags, strict=False): + if flag == "Y": + required.add(col) + return required + + class TestDefinition(Entity): __tablename__ = "test_definitions" @@ -397,6 +438,52 @@ def list_for_suite( _yn_columns: ClassVar = {"test_active", "lock_refresh"} + # Fields editable on every test type regardless of param_columns. + EDITABLE_BASE_FIELDS: ClassVar[frozenset[str]] = frozenset({ + "test_active", "severity", "lock_refresh", "flagged", "test_description", + }) + + def editable_fields(self, test_type: TestType) -> set[str]: + """Fields a caller may set or change on this test definition under the given test type.""" + return self.EDITABLE_BASE_FIELDS | test_type.param_columns + + def validate(self, test_type: TestType) -> None: + """Validate the current state against the given test type. + + Raises :class:`InvalidTestDefinitionFields` with every offending field + and reason — callers see all problems at once. + """ + errors: dict[str, str] = {} + + if self.severity: + try: + Severity(self.severity) + except ValueError: + errors["severity"] = ( + f"must be `{Severity.FAIL.value}` or `{Severity.WARNING.value}` " + f"(got `{self.severity}`)" + ) + + # column_name applies to column-scoped tests (the column under test) and + # custom-scoped tests (a "Test Focus" label). Other scopes don't use it. + if test_type.test_scope not in ("column", "custom") and not _is_blank(self.column_name): + errors["column_name"] = ( + f"test type `{test_type.test_type}` has scope `{test_type.test_scope}`; " + f"column_name does not apply to this scope" + ) + + if not _is_blank(self.custom_query) and "custom_query" not in test_type.param_columns: + errors["custom_query"] = ( + f"test type `{test_type.test_type}` does not accept a custom query" + ) + + for required in _required_fields_for(test_type): + if _is_blank(getattr(self, required, None)): + errors[required] = f"required for test type `{test_type.test_type}`" + + if errors: + raise InvalidTestDefinitionFields(errors) + @classmethod def set_status_attribute( cls, diff --git a/testgen/mcp/server.py b/testgen/mcp/server.py index 2b0539fb..d9d00613 100644 --- a/testgen/mcp/server.py +++ b/testgen/mcp/server.py @@ -162,7 +162,16 @@ def build_mcp_server( test_types_resource, ) from testgen.mcp.tools.source_data import get_source_data, get_source_data_query - from testgen.mcp.tools.test_definitions import get_test, list_test_notes, list_test_types, list_tests + from testgen.mcp.tools.test_definitions import ( + bulk_update_tests, + create_test, + get_test, + list_test_notes, + list_test_types, + list_tests, + update_test, + validate_custom_test, + ) from testgen.mcp.tools.test_results import ( get_failure_summary, get_failure_trend, @@ -227,6 +236,10 @@ def safe_prompt(fn): safe_tool(cancel_test_run) safe_tool(cancel_profiling_run) safe_tool(generate_tests) + safe_tool(create_test) + safe_tool(update_test) + safe_tool(validate_custom_test) + safe_tool(bulk_update_tests) safe_tool(list_hygiene_issues) safe_tool(get_hygiene_issue) safe_tool(search_hygiene_issues) diff --git a/testgen/mcp/tools/common.py b/testgen/mcp/tools/common.py index 4ddb39a2..adf71c05 100644 --- a/testgen/mcp/tools/common.py +++ b/testgen/mcp/tools/common.py @@ -2,13 +2,16 @@ from enum import StrEnum from uuid import UUID +from sqlalchemy import select + from testgen.common.date_service import parse_since from testgen.common.enums import ImpactDimension, QualityDimension +from testgen.common.models import get_current_session from testgen.common.models.hygiene_issue import Disposition, HygieneIssueType, IssueLikelihood, PiiRisk from testgen.common.models.job_execution import JobStatus from testgen.common.models.scheduler import JobSchedule from testgen.common.models.table_group import TableGroup -from testgen.common.models.test_definition import TestType +from testgen.common.models.test_definition import TestDefinition, TestType from testgen.common.models.test_result import TestResultStatus from testgen.common.models.test_suite import TestSuite from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError @@ -250,3 +253,26 @@ def resolve_test_suite(test_suite_id: str) -> TestSuite: if suite is None: raise MCPResourceNotAccessible("Test suite", test_suite_id) return suite + + +def resolve_test_definition(test_definition_id: str) -> TestDefinition: + """Resolve a test definition ID to the live ORM model, collapsing missing-or-inaccessible. + + Filters monitor suites and project access. Returns the ORM ``TestDefinition`` + (not ``TestDefinitionSummary``) so the row can be mutated and saved. + """ + td_uuid = parse_uuid(test_definition_id, "test_definition_id") + perms = get_project_permissions() + query = ( + select(TestDefinition) + .join(TestSuite, TestDefinition.test_suite_id == TestSuite.id) + .where( + TestDefinition.id == td_uuid, + TestSuite.is_monitor.isnot(True), + TestSuite.project_code.in_(perms.allowed_codes), + ) + ) + td = get_current_session().scalars(query).first() + if td is None: + raise MCPResourceNotAccessible("Test definition", test_definition_id) + return td diff --git a/testgen/mcp/tools/test_definitions.py b/testgen/mcp/tools/test_definitions.py index 6d28e3a7..f764999a 100644 --- a/testgen/mcp/tools/test_definitions.py +++ b/testgen/mcp/tools/test_definitions.py @@ -1,6 +1,20 @@ +from datetime import UTC, datetime +from enum import StrEnum +from typing import NoReturn + +from sqlalchemy import update + from testgen.common.enums import ImpactDimension, QualityDimension -from testgen.common.models import with_database_session -from testgen.common.models.test_definition import TestDefinition, TestDefinitionNote, TestDefinitionSummary, TestType +from testgen.common.models import get_current_session, with_database_session +from testgen.common.models.connection import Connection +from testgen.common.models.table_group import TableGroup +from testgen.common.models.test_definition import ( + InvalidTestDefinitionFields, + TestDefinition, + TestDefinitionNote, + TestDefinitionSummary, + TestType, +) from testgen.common.models.test_result import TestResult from testgen.mcp.exceptions import MCPUserError from testgen.mcp.permissions import get_project_permissions, mcp_permission @@ -11,17 +25,25 @@ parse_impact_dimension, parse_quality_dimension, parse_uuid, + resolve_test_definition, + resolve_test_suite, resolve_test_type, validate_limit, validate_page, ) from testgen.mcp.tools.markdown import MdDoc +from testgen.ui.services.database_service import fetch_from_target_db _DOC_GROUP = DocGroup.DISCOVER _VALID_SCOPES = {"column", "table", "referential", "custom"} +class BulkAction(StrEnum): + ENABLE = "enable" + DISABLE = "disable" + + @with_database_session @mcp_permission("view") def list_tests( @@ -118,11 +140,41 @@ def get_test(test_definition_id: str) -> str: if td is None: return f"Test definition `{test_definition_id}` not found." - test_name = td.display_name - doc = MdDoc() + _append_td_summary(doc, td) + + # Last result + results = TestResult.select_history( + test_definition_id=def_uuid, + project_codes=perms.allowed_codes, + limit=1, + ) + doc.heading(2, "Last Result") + if results: + r = results[0] + doc.field("Date", r.test_time) + doc.field("Status", r.status.value if r.status else None) + if r.message: + doc.field("Message", r.message) + else: + doc.text("_No results recorded for this test definition._") + + # Description + description = td.test_description or td.default_test_description + if description: + doc.heading(2, "Description") + doc.text(description) + if td.usage_notes: + doc.heading(2, "Usage Notes") + doc.text(td.usage_notes) + + return doc.render() + + +def _append_td_summary(doc: MdDoc, td: TestDefinitionSummary) -> None: + """Render the identity, configuration, parameters, custom-SQL, and reference-match sections of a test definition.""" + test_name = td.display_name - # Header if td.column_name: doc.heading(1, f"{test_name} on `{td.column_name}` in `{td.table_name}`") else: @@ -158,7 +210,7 @@ def get_test(test_definition_id: str) -> str: doc.field("Export to Observability", "Yes" if td.export_to_observability else "No") # Review status - notes = TestDefinitionNote.get_notes(def_uuid) + notes = TestDefinitionNote.get_notes(td.id) flag_str = "Flagged" if td.flagged else "Not Flagged" note_str = f"{len(notes)} Notes" if notes else "No Notes" doc.field("Review", f"{flag_str}, {note_str}") @@ -185,33 +237,6 @@ def get_test(test_definition_id: str) -> str: # Reference match (only fields listed in param_columns) _append_match_section(doc, td) - # Last result - results = TestResult.select_history( - test_definition_id=def_uuid, - project_codes=perms.allowed_codes, - limit=1, - ) - doc.heading(2, "Last Result") - if results: - r = results[0] - doc.field("Date", r.test_time) - doc.field("Status", r.status.value if r.status else None) - if r.message: - doc.field("Message", r.message) - else: - doc.text("_No results recorded for this test definition._") - - # Description - description = td.test_description or td.default_test_description - if description: - doc.heading(2, "Description") - doc.text(description) - if td.usage_notes: - doc.heading(2, "Usage Notes") - doc.text(td.usage_notes) - - return doc.render() - @with_database_session @mcp_permission("view") @@ -349,3 +374,300 @@ def list_test_types( ) return doc.render() + + +# --------------------------------------------------------------------------- +# Write tools (create / update / validate / bulk-update) +# +# All gated on ``edit`` permission. Atomic semantics on ``update_test`` — +# validation aggregates every field error before raising, so the LLM sees the +# full set in one response and the DB is never touched on a partial-error path. +# --------------------------------------------------------------------------- + + +def _raise_validation_errors(err: InvalidTestDefinitionFields, header: str) -> NoReturn: + """Convert aggregated validation errors into a user-facing ``MCPUserError``.""" + bullets = "\n".join(f"- `{field}`: {reason}" for field, reason in err.errors.items()) + raise MCPUserError(f"{header}\n\n{bullets}") from err + + +@with_database_session +@mcp_permission("edit") +def create_test( + test_suite_id: str, + test_type: str, + table_name: str, + column_name: str | None = None, + threshold_value: str | None = None, + baseline_value: str | None = None, + severity: str | None = None, + custom_query: str | None = None, + extra_params: dict | None = None, +) -> str: + """Create a test in a test suite. + + Args: + test_suite_id: UUID of the test suite. + test_type: Test type name, e.g. ``Alpha Truncation`` or ``Custom Test``. + table_name: Target table name. Case-sensitive. + column_name: Required for column-scoped test types. + threshold_value: Test threshold. + baseline_value: Baseline reference. + severity: ``Fail`` or ``Warning``. Omit to inherit the test type default. + custom_query: SQL for tests that accept a custom query. + extra_params: Additional test-type-specific parameters (e.g. ``window_days``, + ``match_column_names``, ``lower_tolerance``). Use ``list_test_types`` or + ``get_test`` on a similar test to discover supported names. + """ + suite = resolve_test_suite(test_suite_id) + tt_code = resolve_test_type(test_type) + tt = TestType.get(tt_code) + if tt is None: # resolve_test_type already raised if the short name is unknown + raise MCPUserError(f"Unknown test type: `{test_type}`.") + + table_group = TableGroup.get(suite.table_groups_id) + if table_group is None: + raise MCPUserError("Test suite is not associated with a table group.") + + td = TestDefinition( + test_suite_id=suite.id, + table_groups_id=table_group.id, + test_type=tt_code, + schema_name=table_group.table_group_schema, + table_name=table_name, + test_active=True, + lock_refresh=False, + last_manual_update=datetime.now(UTC), + ) + explicit = { + "column_name": column_name, + "threshold_value": threshold_value, + "baseline_value": baseline_value, + "severity": severity, + "custom_query": custom_query, + } + for key, value in explicit.items(): + if value is not None: + setattr(td, key, value) + + if extra_params: + accepted = td.editable_fields(tt) + rejected = sorted(set(extra_params) - accepted) + if rejected: + raise MCPUserError( + f"These `extra_params` keys are not editable for test type `{tt_code}`: " + f"{', '.join(rejected)}." + ) + conflicts = sorted(set(extra_params) & {k for k, v in explicit.items() if v is not None}) + if conflicts: + raise MCPUserError( + f"These fields were set both as named arguments and in `extra_params`: " + f"{', '.join(conflicts)}. Pass each value only once." + ) + for key, value in extra_params.items(): + setattr(td, key, value) + + try: + td.validate(tt) + except InvalidTestDefinitionFields as e: + _raise_validation_errors(e, "Test definition creation rejected. No changes saved.") + + td.save() + + # The joined test-type metadata (param_fields, default_severity, dq_dimension, ...) + # is only present on the Summary dataclass, so re-fetch for rendering. + perms = get_project_permissions() + summary = TestDefinition.get_for_project(td.id, perms.allowed_codes) + + doc = MdDoc() + doc.text(f"**Created** in suite `{suite.test_suite}`.") + _append_td_summary(doc, summary) + return doc.render() + + +@with_database_session +@mcp_permission("edit") +def update_test(test_definition_id: str, fields: dict) -> str: + """Update fields on an existing test. Atomic — no partial save. + + Args: + test_definition_id: UUID of the test definition. + fields: Mapping of field name to new value. Accepts the test type's parameter + columns (use ``get_test`` to see the current values and supported fields) + plus ``test_active``, ``severity``, ``lock_refresh``, ``flagged``. + """ + td = resolve_test_definition(test_definition_id) + tt = TestType.get(td.test_type) + if tt is None: + raise MCPUserError(f"Test type `{td.test_type}` not found for this test definition.") + + if not fields: + raise MCPUserError("No fields supplied to update.") + + accepted = td.editable_fields(tt) + rejected = sorted(set(fields) - accepted) + if rejected: + bullets = "\n".join( + f"- `{key}`: not editable for test type `{tt.test_type}`" for key in rejected + ) + raise MCPUserError(f"Update rejected. No changes saved.\n\n{bullets}") + + before: dict = {key: getattr(td, key, None) for key in fields} + for key, value in fields.items(): + setattr(td, key, value) + td.last_manual_update = datetime.now(UTC) + + try: + td.validate(tt) + except InvalidTestDefinitionFields as e: + _raise_validation_errors(e, "Update rejected. No changes saved.") + + td.save() + + doc = MdDoc() + doc.heading(1, f"Test definition `{td.id}` updated") + rows = [[key, _format_diff(before[key]), _format_diff(fields[key])] for key in fields] + doc.table(["Field", "Before", "After"], rows, code=[0]) + doc.text(f"{len(fields)} field(s) changed.") + return doc.render() + + +def _format_diff(value: object) -> str | None: + """Render a before/after cell, normalizing empty strings to ``None`` (NullIfEmptyString).""" + if value is None or value == "": + return None + if isinstance(value, bool): + return "Yes" if value else "No" + return str(value) + + +@with_database_session +@mcp_permission("edit") +def validate_custom_test(test_suite_id: str, custom_sql: str) -> str: + """Dry-run a custom test SQL query against the test suite's parent connection. + + Args: + test_suite_id: UUID of the test suite whose connection the SQL runs against. + custom_sql: SQL query to dry-run. + """ + suite = resolve_test_suite(test_suite_id) + connection = Connection.get_by_table_group(suite.table_groups_id) + if connection is None: + raise MCPUserError("No connection configured for this test suite's table group.") + + perms = get_project_permissions() + can_view_pii = suite.project_code in perms.codes_allowed_to("view_pii") + + doc = MdDoc() + doc.heading(1, "Custom test dry-run") + + try: + rows = fetch_from_target_db(connection, custom_sql) + except Exception as e: # broad catch: the DB error message IS the user-facing signal + doc.text(f"**SQL did not execute.** Query was not committed against `{connection.connection_name}`.") + message = str(e.args[0]) if e.args else str(e) + doc.text("**Error:**") + doc.code_block(message) + return doc.render() + + row_count = len(rows) + flavor = connection.sql_flavor_code or connection.sql_flavor or "target database" + doc.text( + f"**SQL ran successfully** against `{connection.connection_name}` ({flavor})." + ) + + if row_count == 0: + doc.text("**Would pass:** ✓ — query returned 0 error rows.") + doc.text( + "_If saved as a CUSTOM test, this would currently pass: the test fails when any " + "error rows are returned, and there are none._" + ) + return doc.render() + + doc.text(f"**Would fail:** ✗ — query returned {row_count} error row(s).") + doc.heading(2, "Source data preview (first row)") + first = rows[0] + columns = list(first.keys()) + if can_view_pii: + values = [first[c] for c in columns] + else: + values = ["[redacted]"] * len(columns) + doc.table(columns, [values]) + doc.text( + "_If saved as a CUSTOM test, this would currently fail because the SQL returned error " + "rows. Refine the query if some of those rows are false positives._" + ) + if not can_view_pii: + doc.text( + "_PII redacted: caller does not have `view_pii` on this project. Column names shown " + "so the LLM can iterate on shape; row values are masked._" + ) + return doc.render() + + +@with_database_session +@mcp_permission("edit") +def bulk_update_tests( + test_suite_id: str, + action: str, + table_name: str | None = None, + test_type: str | None = None, +) -> str: + """Enable or disable tests in a suite in bulk. + + Args: + test_suite_id: UUID of the test suite. + action: ``enable`` or ``disable``. + table_name: Optional table-name filter. Case-sensitive. + test_type: Optional test type name (e.g. ``Alpha Truncation``). + """ + try: + bulk_action = BulkAction(action) + except ValueError as err: + valid = ", ".join(f"`{a.value}`" for a in BulkAction) + raise MCPUserError(f"`action` must be one of: {valid}.") from err + suite = resolve_test_suite(test_suite_id) + tt_code = resolve_test_type(test_type) if test_type else None + + target = bulk_action is BulkAction.ENABLE + values: dict = {"test_active": target} + if target: + # Mirrors set_status_attribute: clearing the status when re-enabling so failed + # tests don't carry forward a stale "disabled because of X" marker. + values["test_definition_status"] = None + + where_clauses = [TestDefinition.test_suite_id == suite.id] + if table_name: + where_clauses.append(TestDefinition.table_name == table_name) + if tt_code: + where_clauses.append(TestDefinition.test_type == tt_code) + + stmt = ( + update(TestDefinition) + .where(*where_clauses) + .values(**values) + .returning(TestDefinition.id) + ) + session = get_current_session() + affected = session.execute(stmt).all() + count = len(affected) + + verb = "Enabled" if target else "Disabled" + filters = [] + if table_name: + filters.append(f"table_name=`{table_name}`") + if test_type: + filters.append(f"test_type=`{test_type}`") + filter_str = ", ".join(filters) if filters else "no filter" + + doc = MdDoc() + if count == 0: + doc.heading(1, "No tests matched") + doc.text( + f"No tests in suite `{suite.test_suite}` matched the filter ({filter_str}). Nothing changed." + ) + return doc.render() + + doc.heading(1, f"{verb} {count} test(s) in suite `{suite.test_suite}`") + doc.field("Filter", filter_str) + return doc.render() diff --git a/tests/unit/common/models/test_test_definition.py b/tests/unit/common/models/test_test_definition.py new file mode 100644 index 00000000..929c7eee --- /dev/null +++ b/tests/unit/common/models/test_test_definition.py @@ -0,0 +1,207 @@ +"""Tests for TestDefinition.validate() and TestDefinition.editable_fields().""" + +from unittest.mock import MagicMock + +import pytest + +from testgen.common.models.test_definition import ( + InvalidTestDefinitionFields, + Severity, + TestDefinition, + _required_fields_for, +) + + +def make_test_type( + code: str = "Alpha_Trunc", + scope: str = "column", + param_columns: set[str] | None = None, + default_parm_columns: str | None = "threshold_value", + default_parm_required: str | None = None, +) -> MagicMock: + """Build a TestType-shaped mock with the attributes the validator reads.""" + tt = MagicMock() + tt.test_type = code + tt.test_scope = scope + tt.param_columns = param_columns if param_columns is not None else {"threshold_value"} + tt.default_parm_columns = default_parm_columns + tt.default_parm_required = default_parm_required + return tt + + +def make_td(**fields) -> TestDefinition: + """Build a TestDefinition with the given fields set, nothing else.""" + td = TestDefinition() + for key, value in fields.items(): + setattr(td, key, value) + return td + + +# -- _required_fields_for ----------------------------------------------------- + + +def test_required_fields_column_scope_adds_column_name(): + tt = make_test_type(scope="column") + assert "column_name" in _required_fields_for(tt) + + +def test_required_fields_table_scope_no_column_name(): + tt = make_test_type(code="Row_Ct", scope="table", param_columns=set(), default_parm_columns=None) + assert "column_name" not in _required_fields_for(tt) + + +def test_required_fields_custom_query_when_in_param_columns(): + tt = make_test_type( + code="CUSTOM", + scope="custom", + param_columns={"custom_query", "match_column_names"}, + default_parm_columns="custom_query,match_column_names", + ) + assert "custom_query" in _required_fields_for(tt) + + +def test_required_fields_parses_default_parm_required(): + tt = make_test_type( + code="Metric_Trend", + scope="custom", + param_columns={"custom_query", "threshold_value", "baseline_value"}, + default_parm_columns="custom_query,threshold_value,baseline_value", + default_parm_required="Y,Y,N", + ) + required = _required_fields_for(tt) + assert "custom_query" in required + assert "threshold_value" in required + assert "baseline_value" not in required + + +def test_required_fields_null_required_means_no_extras(): + tt = make_test_type(scope="column", default_parm_required=None) + assert _required_fields_for(tt) == {"column_name"} + + +# -- TestDefinition.editable_fields ------------------------------------------- + + +def test_editable_fields_includes_base_set(): + tt = make_test_type(param_columns=set(), default_parm_columns=None) + td = make_td() + accepted = td.editable_fields(tt) + assert {"test_active", "severity", "lock_refresh", "flagged", "test_description"} <= accepted + + +def test_editable_fields_includes_param_columns(): + tt = make_test_type(param_columns={"threshold_value", "baseline_value"}) + td = make_td() + accepted = td.editable_fields(tt) + assert {"threshold_value", "baseline_value"} <= accepted + + +def test_editable_fields_does_not_leak_identity_or_internal_columns(): + tt = make_test_type(param_columns={"threshold_value"}) + td = make_td() + accepted = td.editable_fields(tt) + # Identity fields — callers must never set these via fields/extra_params + for forbidden in ("test_suite_id", "table_groups_id", "test_type", "schema_name"): + assert forbidden not in accepted + # Internal/system-managed columns + for forbidden in ("profile_run_id", "external_id", "prediction", "last_auto_gen_date"): + assert forbidden not in accepted + + +# -- TestDefinition.validate -------------------------------------------------- + + +def test_validate_happy_path(): + tt = make_test_type() + td = make_td(column_name="email", threshold_value="10") + td.validate(tt) # no raise + + +def test_validate_missing_required_column_name(): + tt = make_test_type(scope="column") + td = make_td(threshold_value="10") # no column_name + with pytest.raises(InvalidTestDefinitionFields) as exc_info: + td.validate(tt) + assert "column_name" in exc_info.value.errors + + +def test_validate_wrong_scope_column_name_rejected(): + tt = make_test_type(code="Row_Ct", scope="table", param_columns=set()) + td = make_td(column_name="email") + with pytest.raises(InvalidTestDefinitionFields) as exc_info: + td.validate(tt) + assert "column_name" in exc_info.value.errors + + +def test_validate_custom_scope_accepts_column_name_as_label(): + # CUSTOM uses column_name as a "Test Focus" label — must be accepted. + tt = make_test_type( + code="CUSTOM", + scope="custom", + param_columns={"custom_query"}, + default_parm_columns="custom_query", + ) + td = make_td(column_name="Negative Total Check", custom_query="SELECT 1") + td.validate(tt) # no raise + + +def test_validate_custom_query_not_accepted(): + tt = make_test_type() # param_columns = {threshold_value}; no custom_query allowed + td = make_td(column_name="email", threshold_value="10", custom_query="SELECT 1") + with pytest.raises(InvalidTestDefinitionFields) as exc_info: + td.validate(tt) + assert "custom_query" in exc_info.value.errors + + +def test_validate_severity_accepts_valid_strenum_values(): + tt = make_test_type() + for value in ("Fail", "Warning"): + td = make_td(column_name="email", threshold_value="10", severity=value) + td.validate(tt) + + +def test_validate_severity_rejects_invalid(): + tt = make_test_type() + td = make_td(column_name="email", threshold_value="10", severity="critical") + with pytest.raises(InvalidTestDefinitionFields) as exc_info: + td.validate(tt) + assert "severity" in exc_info.value.errors + + +def test_validate_severity_case_sensitive(): + # Per CLAUDE.md, case-sensitive — "fail" must be rejected. + tt = make_test_type() + td = make_td(column_name="email", threshold_value="10", severity="fail") + with pytest.raises(InvalidTestDefinitionFields) as exc_info: + td.validate(tt) + assert "severity" in exc_info.value.errors + + +def test_validate_severity_empty_string_treated_as_unset(): + tt = make_test_type() + td = make_td(column_name="email", threshold_value="10", severity="") + td.validate(tt) # empty severity is OK — falls back to test type default + + +def test_validate_aggregates_errors(): + tt = make_test_type(scope="column") + td = make_td(severity="critical", custom_query="SELECT 1") # no column_name + with pytest.raises(InvalidTestDefinitionFields) as exc_info: + td.validate(tt) + errors = exc_info.value.errors + assert {"column_name", "severity", "custom_query"} <= errors.keys() + + +def test_validate_empty_string_treats_required_field_as_cleared(): + tt = make_test_type(scope="column") + td = make_td(column_name="", threshold_value="10") + with pytest.raises(InvalidTestDefinitionFields) as exc_info: + td.validate(tt) + assert "column_name" in exc_info.value.errors + + +def test_severity_enum_value_accepted(): + # StrEnum subclasses str, so setting severity to the enum should pass validate. + tt = make_test_type() + td = make_td(column_name="email", threshold_value="10", severity=Severity.FAIL) + td.validate(tt) diff --git a/tests/unit/mcp/test_tools_test_definitions.py b/tests/unit/mcp/test_tools_test_definitions.py index 5dea0d03..e67e18a2 100644 --- a/tests/unit/mcp/test_tools_test_definitions.py +++ b/tests/unit/mcp/test_tools_test_definitions.py @@ -545,3 +545,590 @@ def test_list_test_types_filter_description(mock_tt, db_session_mock): assert "scope: table" in result assert "dimension: Completeness" in result + + +# -- create_test -------------------------------------------------------------- + + +def _make_suite(suite_id=None, table_groups_id=None): + suite = MagicMock() + suite.id = suite_id or uuid4() + suite.test_suite = "demo_suite" + suite.project_code = "demo" + suite.table_groups_id = table_groups_id or uuid4() + return suite + + +def _make_test_type( + code="Alpha_Trunc", + short_name="Alpha Truncation", + scope="column", + param_columns=None, + default_parm_columns="threshold_value", + default_parm_required=None, + default_severity="Fail", +): + tt = MagicMock() + tt.test_type = code + tt.test_name_short = short_name + tt.test_scope = scope + tt.param_columns = param_columns if param_columns is not None else {"threshold_value"} + tt.default_parm_columns = default_parm_columns + tt.default_parm_required = default_parm_required + tt.default_severity = default_severity + return tt + + +def _make_table_group(schema="public"): + tg = MagicMock() + tg.id = uuid4() + tg.table_group_schema = schema + return tg + + +def _make_td_summary(table_name="orders", column_name="email", severity="Warning"): + """Mock TestDefinitionSummary as returned by TestDefinition.get_for_project().""" + summary = MagicMock() + summary.id = uuid4() + summary.display_name = "Alpha Truncation" + summary.test_type = "Alpha_Trunc" + summary.test_name_short = "Alpha Truncation" + summary.table_name = table_name + summary.column_name = column_name + summary.schema_name = "demo" + summary.test_scope = "column" + summary.test_suite_id = uuid4() + summary.impact_dimension = None + summary.default_impact_dimension = "Conformance" + summary.dq_dimension = "Validity" + summary.test_active = True + summary.severity = severity + summary.default_severity = "Fail" + summary.lock_refresh = False + summary.export_to_observability = True + summary.flagged = False + summary.last_auto_gen_date = None + summary.last_manual_update = None + summary.default_parm_columns = "threshold_value" + summary.param_columns = {"threshold_value"} + summary.param_fields = [("threshold_value", "Maximum String Length at Baseline", "")] + summary.threshold_value = "64" + summary.custom_query = None + summary.match_schema_name = None + summary.match_table_name = None + summary.match_column_names = None + summary.match_subset_condition = None + summary.match_groupby_names = None + summary.match_having_condition = None + return summary + + +@patch("testgen.mcp.tools.test_definitions.TestDefinitionNote") +@patch("testgen.mcp.tools.test_definitions.TestDefinition") +@patch("testgen.mcp.tools.test_definitions.TableGroup") +@patch("testgen.mcp.tools.test_definitions.TestType") +@patch("testgen.mcp.tools.test_definitions.resolve_test_type") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_create_test_happy_path( + mock_resolve_suite, mock_resolve_tt, mock_tt_model, mock_tg, mock_td, mock_notes, db_session_mock, +): + suite = _make_suite() + mock_resolve_suite.return_value = suite + mock_resolve_tt.return_value = "Alpha_Trunc" + mock_tt_model.get.return_value = _make_test_type() + mock_tg.get.return_value = _make_table_group() + + saved = MagicMock() + saved.id = uuid4() + mock_td.return_value = saved + mock_td.get_for_project.return_value = _make_td_summary() + mock_notes.get_notes.return_value = [] + + from testgen.mcp.tools.test_definitions import create_test + + result = create_test( + test_suite_id=str(uuid4()), + test_type="Alpha Truncation", + table_name="orders", + column_name="email", + threshold_value="64", + severity="Warning", + ) + + # New shared body: entity-first heading + "Created in suite" lead-in + assert "Created" in result + assert "Alpha Truncation on `email` in `orders`" in result + # Parameters table uses the test type's prompt, not a hardcoded label + assert "Maximum String Length at Baseline" in result + assert "64" in result + assert "Warning" in result + saved.save.assert_called_once() + + +@patch("testgen.mcp.tools.test_definitions.TableGroup") +@patch("testgen.mcp.tools.test_definitions.TestType") +@patch("testgen.mcp.tools.test_definitions.resolve_test_type") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_create_test_column_scope_requires_column_name( + mock_resolve_suite, mock_resolve_tt, mock_tt_model, mock_tg, db_session_mock +): + mock_resolve_suite.return_value = _make_suite() + mock_resolve_tt.return_value = "Alpha_Trunc" + mock_tt_model.get.return_value = _make_test_type() # column scope + mock_tg.get.return_value = _make_table_group() + + from testgen.mcp.tools.test_definitions import create_test + + with pytest.raises(MCPUserError) as exc_info: + create_test( + test_suite_id=str(uuid4()), + test_type="Alpha Truncation", + table_name="orders", + threshold_value="64", + ) + assert "column_name" in str(exc_info.value) + assert "rejected" in str(exc_info.value).lower() + + +@patch("testgen.mcp.tools.test_definitions.TableGroup") +@patch("testgen.mcp.tools.test_definitions.TestType") +@patch("testgen.mcp.tools.test_definitions.resolve_test_type") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_create_test_custom_query_not_accepted_on_alpha_trunc( + mock_resolve_suite, mock_resolve_tt, mock_tt_model, mock_tg, db_session_mock +): + mock_resolve_suite.return_value = _make_suite() + mock_resolve_tt.return_value = "Alpha_Trunc" + mock_tt_model.get.return_value = _make_test_type() # param_columns = {threshold_value} + mock_tg.get.return_value = _make_table_group() + + from testgen.mcp.tools.test_definitions import create_test + + with pytest.raises(MCPUserError) as exc_info: + create_test( + test_suite_id=str(uuid4()), + test_type="Alpha Truncation", + table_name="orders", + column_name="email", + threshold_value="64", + custom_query="SELECT 1", + ) + assert "custom_query" in str(exc_info.value) + + +@patch("testgen.mcp.tools.test_definitions.TestDefinitionNote") +@patch("testgen.mcp.tools.test_definitions.TestDefinition") +@patch("testgen.mcp.tools.test_definitions.TableGroup") +@patch("testgen.mcp.tools.test_definitions.TestType") +@patch("testgen.mcp.tools.test_definitions.resolve_test_type") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_create_test_extra_params_pass_through( + mock_resolve_suite, mock_resolve_tt, mock_tt_model, mock_tg, mock_td, mock_notes, db_session_mock, +): + """extra_params adds fields not in the explicit kwargs (e.g. window_days).""" + mock_resolve_suite.return_value = _make_suite() + mock_resolve_tt.return_value = "Some_Trend" + # Test type accepts threshold_value AND window_days + mock_tt_model.get.return_value = _make_test_type( + code="Some_Trend", + param_columns={"threshold_value", "window_days"}, + default_parm_columns="threshold_value,window_days", + ) + mock_tg.get.return_value = _make_table_group() + saved_td = MagicMock(id=uuid4()) + saved_td.editable_fields.return_value = { + "test_active", "severity", "lock_refresh", "flagged", "test_description", + "threshold_value", "window_days", + } + mock_td.return_value = saved_td + mock_td.get_for_project.return_value = _make_td_summary() + mock_notes.get_notes.return_value = [] + + from testgen.mcp.tools.test_definitions import create_test + + create_test( + test_suite_id=str(uuid4()), + test_type="Some Trend", + table_name="orders", + column_name="email", + threshold_value="10", + extra_params={"window_days": "7"}, + ) + + # threshold_value (from kwarg) and window_days (from extras) were both setattr'd on the TD + assert saved_td.threshold_value == "10" + assert saved_td.window_days == "7" + saved_td.validate.assert_called_once() + saved_td.save.assert_called_once() + + +@patch("testgen.mcp.tools.test_definitions.TableGroup") +@patch("testgen.mcp.tools.test_definitions.TestType") +@patch("testgen.mcp.tools.test_definitions.resolve_test_type") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_create_test_extra_params_conflict_rejected( + mock_resolve_suite, mock_resolve_tt, mock_tt_model, mock_tg, db_session_mock, +): + """Passing the same field via both kwarg and extra_params is rejected.""" + mock_resolve_suite.return_value = _make_suite() + mock_resolve_tt.return_value = "Alpha_Trunc" + mock_tt_model.get.return_value = _make_test_type() + mock_tg.get.return_value = _make_table_group() + + from testgen.mcp.tools.test_definitions import create_test + + with pytest.raises(MCPUserError, match="both as named arguments and in"): + create_test( + test_suite_id=str(uuid4()), + test_type="Alpha Truncation", + table_name="orders", + column_name="email", + threshold_value="10", + extra_params={"threshold_value": "20"}, + ) + + +@patch("testgen.mcp.tools.test_definitions.TableGroup") +@patch("testgen.mcp.tools.test_definitions.TestType") +@patch("testgen.mcp.tools.test_definitions.resolve_test_type") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_create_test_extra_params_unknown_field_rejected_via_validator( + mock_resolve_suite, mock_resolve_tt, mock_tt_model, mock_tg, db_session_mock, +): + """Unknown field in extra_params surfaces through the validator's wrong-scope/unaccepted rules.""" + mock_resolve_suite.return_value = _make_suite() + mock_resolve_tt.return_value = "Alpha_Trunc" + mock_tt_model.get.return_value = _make_test_type() # param_columns = {threshold_value} + mock_tg.get.return_value = _make_table_group() + + from testgen.mcp.tools.test_definitions import create_test + + # custom_query isn't accepted by Alpha_Trunc — validator should reject + with pytest.raises(MCPUserError) as exc_info: + create_test( + test_suite_id=str(uuid4()), + test_type="Alpha Truncation", + table_name="orders", + column_name="email", + threshold_value="10", + extra_params={"custom_query": "SELECT 1"}, + ) + assert "custom_query" in str(exc_info.value) + + +@patch("testgen.mcp.tools.test_definitions.TableGroup") +@patch("testgen.mcp.tools.test_definitions.TestType") +@patch("testgen.mcp.tools.test_definitions.resolve_test_type") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_create_test_severity_invalid( + mock_resolve_suite, mock_resolve_tt, mock_tt_model, mock_tg, db_session_mock +): + mock_resolve_suite.return_value = _make_suite() + mock_resolve_tt.return_value = "Alpha_Trunc" + mock_tt_model.get.return_value = _make_test_type() + mock_tg.get.return_value = _make_table_group() + + from testgen.mcp.tools.test_definitions import create_test + + with pytest.raises(MCPUserError) as exc_info: + create_test( + test_suite_id=str(uuid4()), + test_type="Alpha Truncation", + table_name="orders", + column_name="email", + threshold_value="64", + severity="critical", + ) + assert "severity" in str(exc_info.value) + + +# -- update_test -------------------------------------------------------------- + + +def _make_td_orm(test_type="Alpha_Trunc", threshold_value="64", severity="Warning"): + td = MagicMock() + td.id = uuid4() + td.test_type = test_type + td.threshold_value = threshold_value + td.severity = severity + td.test_active = True + td.lock_refresh = False + td.flagged = False + # Mirror TestDefinition.editable_fields(tt) for an Alpha_Trunc-shaped test type + td.editable_fields.return_value = { + "test_active", "severity", "lock_refresh", "flagged", "test_description", + "threshold_value", + } + return td + + +@patch("testgen.mcp.tools.test_definitions.TestType") +@patch("testgen.mcp.tools.test_definitions.resolve_test_definition") +def test_update_test_happy_path(mock_resolve_td, mock_tt_model, db_session_mock): + td = _make_td_orm() + mock_resolve_td.return_value = td + mock_tt_model.get.return_value = _make_test_type() + + from testgen.mcp.tools.test_definitions import update_test + + result = update_test(str(td.id), fields={"threshold_value": "80"}) + + assert "updated" in result.lower() + assert "threshold_value" in result + assert "80" in result + assert td.threshold_value == "80" + td.save.assert_called_once() + + +@patch("testgen.mcp.tools.test_definitions.TestType") +@patch("testgen.mcp.tools.test_definitions.resolve_test_definition") +def test_update_test_empty_fields_rejected(mock_resolve_td, mock_tt_model, db_session_mock): + td = _make_td_orm() + mock_resolve_td.return_value = td + mock_tt_model.get.return_value = _make_test_type() + + from testgen.mcp.tools.test_definitions import update_test + + with pytest.raises(MCPUserError): + update_test(str(td.id), fields={}) + td.save.assert_not_called() + + +@patch("testgen.mcp.tools.test_definitions.TestType") +@patch("testgen.mcp.tools.test_definitions.resolve_test_definition") +def test_update_test_unknown_field_rejected_no_partial(mock_resolve_td, mock_tt_model, db_session_mock): + td = _make_td_orm() + mock_resolve_td.return_value = td + mock_tt_model.get.return_value = _make_test_type() + + from testgen.mcp.tools.test_definitions import update_test + + with pytest.raises(MCPUserError) as exc_info: + # threshold_value is valid, table_name is not — must reject ALL + update_test(str(td.id), fields={"threshold_value": "80", "table_name": "new"}) + assert "table_name" in str(exc_info.value) + # td.threshold_value should NOT have been mutated + assert td.threshold_value == "64" + td.save.assert_not_called() + + +@patch("testgen.mcp.tools.test_definitions.TestType") +@patch("testgen.mcp.tools.test_definitions.resolve_test_definition") +def test_update_test_multi_field(mock_resolve_td, mock_tt_model, db_session_mock): + td = _make_td_orm() + mock_resolve_td.return_value = td + mock_tt_model.get.return_value = _make_test_type() + + from testgen.mcp.tools.test_definitions import update_test + + result = update_test( + str(td.id), + fields={"threshold_value": "80", "severity": "Fail", "test_active": False}, + ) + assert "3 field" in result + td.save.assert_called_once() + + +# -- validate_custom_test ----------------------------------------------------- + + +@patch("testgen.mcp.tools.test_definitions.fetch_from_target_db") +@patch("testgen.mcp.tools.test_definitions.Connection") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_validate_custom_test_would_pass_when_no_rows( + mock_resolve_suite, mock_conn, mock_fetch, db_session_mock +): + mock_resolve_suite.return_value = _make_suite() + conn = MagicMock() + conn.connection_name = "warehouse" + conn.sql_flavor_code = "snowflake" + conn.sql_flavor = "snowflake" + mock_conn.get_by_table_group.return_value = conn + mock_fetch.return_value = [] + + from testgen.mcp.tools.test_definitions import validate_custom_test + + result = validate_custom_test(str(uuid4()), "SELECT 1 WHERE 1=0") + + assert "ran successfully" in result.lower() + assert "would pass" in result.lower() + assert "0 error rows" in result + + +@patch("testgen.mcp.permissions._compute_project_permissions") +@patch("testgen.mcp.tools.test_definitions.fetch_from_target_db") +@patch("testgen.mcp.tools.test_definitions.Connection") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_validate_custom_test_would_fail_shows_preview_with_view_pii( + mock_resolve_suite, mock_conn, mock_fetch, mock_compute, db_session_mock +): + # Grant view_pii on "demo" so values are visible in the preview. + from testgen.mcp.permissions import ProjectPermissions + + perms = MagicMock(spec=ProjectPermissions) + perms.allowed_codes = ["demo"] + perms.codes_allowed_to.return_value = ["demo"] + perms.has_access.side_effect = lambda code: code == "demo" + mock_compute.return_value = perms + + mock_resolve_suite.return_value = _make_suite() + conn = MagicMock() + conn.connection_name = "warehouse" + conn.sql_flavor_code = "snowflake" + conn.sql_flavor = "snowflake" + mock_conn.get_by_table_group.return_value = conn + + row = MagicMock() + row.keys.return_value = ["order_id", "amount"] + row.__getitem__.side_effect = lambda k: {"order_id": "ORD-123", "amount": "-45.99"}[k] + mock_fetch.return_value = [row, row, row] + + from testgen.mcp.tools.test_definitions import validate_custom_test + + result = validate_custom_test(str(uuid4()), "SELECT * FROM orders WHERE amount < 0") + + assert "would fail" in result.lower() + assert "3 error row" in result + assert "order_id" in result + assert "ORD-123" in result + # No redaction banner when view_pii is granted + assert "[redacted]" not in result + + +@patch("testgen.mcp.tools.test_definitions.fetch_from_target_db") +@patch("testgen.mcp.tools.test_definitions.Connection") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_validate_custom_test_redacts_when_no_view_pii( + mock_resolve_suite, mock_conn, mock_fetch, db_session_mock +): + # Default fixture user has role_a with edit but not view_pii (view_pii not in test matrix → empty) + mock_resolve_suite.return_value = _make_suite() + conn = MagicMock() + conn.connection_name = "warehouse" + conn.sql_flavor_code = "snowflake" + conn.sql_flavor = "snowflake" + mock_conn.get_by_table_group.return_value = conn + + row = MagicMock() + row.keys.return_value = ["order_id", "customer_email"] + row.__getitem__.side_effect = lambda k: {"order_id": "ORD-123", "customer_email": "jane@example.com"}[k] + mock_fetch.return_value = [row] + + from testgen.mcp.tools.test_definitions import validate_custom_test + + result = validate_custom_test(str(uuid4()), "SELECT * FROM orders") + + # Column names always visible + assert "order_id" in result + assert "customer_email" in result + # Values redacted because view_pii not granted in the default test matrix + assert "[redacted]" in result + assert "jane@example.com" not in result + assert "ORD-123" not in result + + +@patch("testgen.mcp.tools.test_definitions.fetch_from_target_db") +@patch("testgen.mcp.tools.test_definitions.Connection") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_validate_custom_test_sql_error_surfaced( + mock_resolve_suite, mock_conn, mock_fetch, db_session_mock +): + mock_resolve_suite.return_value = _make_suite() + conn = MagicMock() + conn.connection_name = "warehouse" + conn.sql_flavor_code = "postgresql" + conn.sql_flavor = "postgresql" + mock_conn.get_by_table_group.return_value = conn + mock_fetch.side_effect = Exception('syntax error at or near "FROMM"') + + from testgen.mcp.tools.test_definitions import validate_custom_test + + result = validate_custom_test(str(uuid4()), "SELECT * FROMM orders") + + assert "did not execute" in result.lower() + assert "syntax error" in result + + +@patch("testgen.mcp.tools.test_definitions.Connection") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_validate_custom_test_missing_connection(mock_resolve_suite, mock_conn, db_session_mock): + mock_resolve_suite.return_value = _make_suite() + mock_conn.get_by_table_group.return_value = None + + from testgen.mcp.tools.test_definitions import validate_custom_test + + with pytest.raises(MCPUserError, match="No connection"): + validate_custom_test(str(uuid4()), "SELECT 1") + + +# -- bulk_update_tests -------------------------------------------------------- + + +@patch("testgen.mcp.tools.test_definitions.get_current_session") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_bulk_update_tests_disable_no_filter(mock_resolve_suite, mock_session, db_session_mock): + mock_resolve_suite.return_value = _make_suite() + result_mock = MagicMock() + result_mock.all.return_value = [(uuid4(),), (uuid4(),), (uuid4(),)] + mock_session.return_value.execute.return_value = result_mock + + from testgen.mcp.tools.test_definitions import bulk_update_tests + + result = bulk_update_tests(test_suite_id=str(uuid4()), action="disable") + + assert "Disabled" in result + assert "3 test" in result + assert "no filter" in result + + +@patch("testgen.mcp.tools.test_definitions.get_current_session") +@patch("testgen.mcp.tools.test_definitions.resolve_test_type") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_bulk_update_tests_enable_with_table_filter( + mock_resolve_suite, mock_resolve_tt, mock_session, db_session_mock +): + mock_resolve_suite.return_value = _make_suite() + result_mock = MagicMock() + result_mock.all.return_value = [(uuid4(),)] + mock_session.return_value.execute.return_value = result_mock + + from testgen.mcp.tools.test_definitions import bulk_update_tests + + result = bulk_update_tests( + test_suite_id=str(uuid4()), action="enable", table_name="legacy_orders" + ) + + assert "Enabled" in result + assert "legacy_orders" in result + mock_resolve_tt.assert_not_called() # not called when test_type filter absent + + +@patch("testgen.mcp.tools.test_definitions.get_current_session") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_bulk_update_tests_invalid_action(mock_resolve_suite, mock_session, db_session_mock): + mock_resolve_suite.return_value = _make_suite() + + from testgen.mcp.tools.test_definitions import bulk_update_tests + + with pytest.raises(MCPUserError, match="`action`"): + bulk_update_tests(test_suite_id=str(uuid4()), action="toggle") + + # Suite resolution happens before action validation in current code path? + # Actually, action is validated first; resolve_test_suite shouldn't have been called. + mock_resolve_suite.assert_not_called() + + +@patch("testgen.mcp.tools.test_definitions.get_current_session") +@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") +def test_bulk_update_tests_no_match(mock_resolve_suite, mock_session, db_session_mock): + mock_resolve_suite.return_value = _make_suite() + result_mock = MagicMock() + result_mock.all.return_value = [] + mock_session.return_value.execute.return_value = result_mock + + from testgen.mcp.tools.test_definitions import bulk_update_tests + + result = bulk_update_tests(test_suite_id=str(uuid4()), action="disable", table_name="nonexistent") + + assert "No tests matched" in result + assert "nonexistent" in result From 55cca1aa71796c6f523b7b510d7fe90f2e3caa77 Mon Sep 17 00:00:00 2001 From: Luis Date: Thu, 7 May 2026 21:50:50 -0400 Subject: [PATCH 06/58] refactor(mcp): add get_column_profile_detail tool New tool to get the deep profile for one column. --- testgen/common/models/data_column.py | 234 ++++++++ testgen/common/models/profile_result.py | 62 ++- testgen/mcp/server.py | 9 +- testgen/mcp/tools/common.py | 15 + testgen/mcp/tools/markdown.py | 8 +- testgen/mcp/tools/profiling.py | 265 ++++++++- testgen/mcp/tools/reference.py | 125 +++++ tests/unit/mcp/test_model_data_column.py | 240 ++++++++ tests/unit/mcp/test_tools_common.py | 55 +- tests/unit/mcp/test_tools_profiling.py | 669 ++++++++++++++++++++++- tests/unit/mcp/test_tools_reference.py | 87 +++ 11 files changed, 1737 insertions(+), 32 deletions(-) create mode 100644 tests/unit/mcp/test_model_data_column.py diff --git a/testgen/common/models/data_column.py b/testgen/common/models/data_column.py index 0280a28b..81d9c125 100644 --- a/testgen/common/models/data_column.py +++ b/testgen/common/models/data_column.py @@ -17,9 +17,11 @@ ) from sqlalchemy.dialects import postgresql +from testgen.common.models import get_current_session from testgen.common.models.entity import Entity, EntityMinimal from testgen.common.models.hygiene_issue import HygieneIssue from testgen.common.models.profile_result import ProfileResult +from testgen.common.models.profiling_run import ProfilingRun @dataclass @@ -40,6 +42,88 @@ class ColumnProfileSummary(EntityMinimal): hygiene_issue_count: int +@dataclass +class ColumnProfileDetail(EntityMinimal): + """L2 column profiling detail — header fields plus type-specific stats and run identity.""" + + # Identity + column_name: str + table_name: str + schema_name: str | None + # Types & metadata + general_type: str | None + column_type: str | None + db_data_type: str | None + functional_data_type: str | None + datatype_suggestion: str | None + functional_table_type: str | None + pii_flag: str | None + critical_data_element: bool | None + # Counts + record_ct: int | None + value_ct: int | None + distinct_value_ct: int | None + null_value_ct: int | None + filled_value_ct: int | None + zero_value_ct: int | None + # Alpha + min_length: int | None + max_length: int | None + avg_length: float | None + min_text: str | None + max_text: str | None + top_freq_values: str | None + top_patterns: str | None + distinct_std_value_ct: int | None + distinct_pattern_ct: int | None + std_pattern_match: str | None + mixed_case_ct: int | None + lower_case_ct: int | None + upper_case_ct: int | None + non_alpha_ct: int | None + includes_digit_ct: int | None + numeric_ct: int | None + date_ct: int | None + quoted_value_ct: int | None + lead_space_ct: int | None + embedded_space_ct: int | None + avg_embedded_spaces: float | None + zero_length_ct: int | None + # Numeric + min_value: float | None + min_value_over_0: float | None + max_value: float | None + avg_value: float | None + stdev_value: float | None + percentile_25: float | None + percentile_50: float | None + percentile_75: float | None + # Date + min_date: datetime | None + max_date: datetime | None + before_1yr_date_ct: int | None + before_5yr_date_ct: int | None + before_20yr_date_ct: int | None + within_1yr_date_ct: int | None + within_1mo_date_ct: int | None + future_date_ct: int | None + # Boolean + boolean_true_ct: int | None + # Per-column profiling failure + query_error: str | None + # Scores & hygiene + dq_score_profiling: float | None + dq_score_testing: float | None + hygiene_issue_count: int + # Run identity + profile_run_id: UUID | None + profile_run_je_id: UUID | None + profile_run_status: str | None + profile_run_started_at: datetime | None + profile_run_ended_at: datetime | None + profile_run_log_message: str | None + + class DataColumnChars(Entity): __tablename__ = "data_column_chars" @@ -166,3 +250,153 @@ def list_for_table_group( ) return cls._paginate(query, page=page, limit=limit, data_class=ColumnProfileSummary) + + @classmethod + def get_column_detail( + cls, + table_groups_id: UUID, + table_name: str, + column_name: str, + profiling_run_id: UUID | None = None, + ) -> ColumnProfileDetail | None: + """Fetch the L2 profile detail for a single column. + + When ``profiling_run_id`` is None, joins on the column's + ``last_complete_profile_run_id`` so the caller gets the latest run. + Returns None when the column does not exist in the table group. + """ + from testgen.common.models.data_table import DataTable + + profile_run_filter = ( + ProfileResult.profile_run_id == profiling_run_id + if profiling_run_id is not None + else ProfileResult.profile_run_id == cls.last_complete_profile_run_id + ) + + hygiene_subq = ( + select( + HygieneIssue.profile_run_id.label("profile_run_id"), + HygieneIssue.schema_name.label("schema_name"), + HygieneIssue.table_name.label("table_name"), + HygieneIssue.column_name.label("column_name"), + func.count().label("hygiene_issue_count"), + ) + .where( + HygieneIssue.table_groups_id == table_groups_id, + func.coalesce(HygieneIssue.disposition, "Confirmed") == "Confirmed", + ) + .group_by( + HygieneIssue.profile_run_id, + HygieneIssue.schema_name, + HygieneIssue.table_name, + HygieneIssue.column_name, + ) + .subquery() + ) + + cde_coalesced = case( + (cls.critical_data_element.is_(True), True), + (DataTable.critical_data_element.is_(True), True), + else_=False, + ).label("critical_data_element") + + query = ( + select( + cls.column_name, + cls.table_name, + cls.schema_name, + cls.general_type, + ProfileResult.column_type, + cls.db_data_type, + cls.functional_data_type, + ProfileResult.datatype_suggestion, + ProfileResult.functional_table_type, + cls.pii_flag, + cde_coalesced, + ProfileResult.record_ct, + ProfileResult.value_ct, + ProfileResult.distinct_value_ct, + ProfileResult.null_value_ct, + ProfileResult.filled_value_ct, + ProfileResult.zero_value_ct, + ProfileResult.min_length, + ProfileResult.max_length, + ProfileResult.avg_length, + ProfileResult.min_text, + ProfileResult.max_text, + ProfileResult.top_freq_values, + ProfileResult.top_patterns, + ProfileResult.distinct_std_value_ct, + ProfileResult.distinct_pattern_ct, + ProfileResult.std_pattern_match, + ProfileResult.mixed_case_ct, + ProfileResult.lower_case_ct, + ProfileResult.upper_case_ct, + ProfileResult.non_alpha_ct, + ProfileResult.includes_digit_ct, + ProfileResult.numeric_ct, + ProfileResult.date_ct, + ProfileResult.quoted_value_ct, + ProfileResult.lead_space_ct, + ProfileResult.embedded_space_ct, + ProfileResult.avg_embedded_spaces, + ProfileResult.zero_length_ct, + ProfileResult.min_value, + ProfileResult.min_value_over_0, + ProfileResult.max_value, + ProfileResult.avg_value, + ProfileResult.stdev_value, + ProfileResult.percentile_25, + ProfileResult.percentile_50, + ProfileResult.percentile_75, + ProfileResult.min_date, + ProfileResult.max_date, + ProfileResult.before_1yr_date_ct, + ProfileResult.before_5yr_date_ct, + ProfileResult.before_20yr_date_ct, + ProfileResult.within_1yr_date_ct, + ProfileResult.within_1mo_date_ct, + ProfileResult.future_date_ct, + ProfileResult.boolean_true_ct, + ProfileResult.query_error, + cls.dq_score_profiling, + cls.dq_score_testing, + func.coalesce(hygiene_subq.c.hygiene_issue_count, 0).label("hygiene_issue_count"), + ProfilingRun.id.label("profile_run_id"), + ProfilingRun.job_execution_id.label("profile_run_je_id"), + ProfilingRun.status.label("profile_run_status"), + ProfilingRun.profiling_starttime.label("profile_run_started_at"), + ProfilingRun.profiling_endtime.label("profile_run_ended_at"), + ProfilingRun.log_message.label("profile_run_log_message"), + ) + .outerjoin(DataTable, DataTable.id == cls.table_id) + .outerjoin( + ProfileResult, + and_( + profile_run_filter, + ProfileResult.schema_name == cls.schema_name, + ProfileResult.table_name == cls.table_name, + ProfileResult.column_name == cls.column_name, + ), + ) + .outerjoin( + hygiene_subq, + and_( + hygiene_subq.c.profile_run_id == ProfileResult.profile_run_id, + hygiene_subq.c.schema_name == cls.schema_name, + hygiene_subq.c.table_name == cls.table_name, + hygiene_subq.c.column_name == cls.column_name, + ), + ) + .outerjoin(ProfilingRun, ProfilingRun.id == ProfileResult.profile_run_id) + .where( + cls.table_groups_id == table_groups_id, + cls.table_name == table_name, + cls.column_name == column_name, + cls.drop_date.is_(None), + ) + .limit(1) + ) + + row = get_current_session().execute(query).mappings().first() + return ColumnProfileDetail(**row) if row else None diff --git a/testgen/common/models/profile_result.py b/testgen/common/models/profile_result.py index 5826e63c..31f37337 100644 --- a/testgen/common/models/profile_result.py +++ b/testgen/common/models/profile_result.py @@ -1,6 +1,7 @@ +from datetime import datetime from uuid import UUID, uuid4 -from sqlalchemy import BigInteger, Column, ForeignKey, Integer, String, asc +from sqlalchemy import BigInteger, Column, Float, ForeignKey, Integer, Numeric, String, asc from sqlalchemy.dialects import postgresql from testgen.common.models.entity import Entity @@ -18,9 +19,11 @@ class ProfileResult(Entity): position: int = Column(Integer) general_type: str | None = Column(String) + column_type: str | None = Column(String) + db_data_type: str | None = Column(String) functional_data_type: str | None = Column(String) + functional_table_type: str | None = Column(String) datatype_suggestion: str | None = Column(String) - db_data_type: str | None = Column(String) pii_flag: str | None = Column(String(50)) record_ct: int | None = Column(BigInteger) @@ -28,8 +31,57 @@ class ProfileResult(Entity): null_value_ct: int | None = Column(BigInteger) distinct_value_ct: int | None = Column(BigInteger) filled_value_ct: int | None = Column(BigInteger) + zero_value_ct: int | None = Column(BigInteger) - _default_order_by = (asc(position), asc(column_name)) + # Alpha-specific + min_length: int | None = Column(Integer) + max_length: int | None = Column(Integer) + avg_length: float | None = Column(Float) + min_text: str | None = Column(String) + max_text: str | None = Column(String) + top_freq_values: str | None = Column(String) + top_patterns: str | None = Column(String) + distinct_std_value_ct: int | None = Column(BigInteger) + distinct_pattern_ct: int | None = Column(BigInteger) + std_pattern_match: str | None = Column(String) + mixed_case_ct: int | None = Column(BigInteger) + lower_case_ct: int | None = Column(BigInteger) + upper_case_ct: int | None = Column(BigInteger) + non_alpha_ct: int | None = Column(BigInteger) + includes_digit_ct: int | None = Column(BigInteger) + numeric_ct: int | None = Column(BigInteger) + date_ct: int | None = Column(BigInteger) + quoted_value_ct: int | None = Column(BigInteger) + lead_space_ct: int | None = Column(BigInteger) + embedded_space_ct: int | None = Column(BigInteger) + avg_embedded_spaces: float | None = Column(Float) + zero_length_ct: int | None = Column(BigInteger) + + # Numeric-specific + min_value: float | None = Column(Float) + min_value_over_0: float | None = Column(Float) + max_value: float | None = Column(Float) + avg_value: float | None = Column(Float) + stdev_value: float | None = Column(Float) + percentile_25: float | None = Column(Float) + percentile_50: float | None = Column(Float) + percentile_75: float | None = Column(Float) + fractional_sum: float | None = Column(Numeric(38, 6)) - # Additional columns exist on this table (type-specific profile stats). - # They'll be mapped here as new MCP tools need them (L2+). + # Date-specific + min_date: datetime | None = Column(postgresql.TIMESTAMP) + max_date: datetime | None = Column(postgresql.TIMESTAMP) + before_1yr_date_ct: int | None = Column(BigInteger) + before_5yr_date_ct: int | None = Column(BigInteger) + before_20yr_date_ct: int | None = Column(BigInteger) + within_1yr_date_ct: int | None = Column(BigInteger) + within_1mo_date_ct: int | None = Column(BigInteger) + future_date_ct: int | None = Column(BigInteger) + + # Boolean-specific + boolean_true_ct: int | None = Column(BigInteger) + + # Per-column profiling failure (independent of run-level status) + query_error: str | None = Column(String) + + _default_order_by = (asc(position), asc(column_name)) diff --git a/testgen/mcp/server.py b/testgen/mcp/server.py index 2b0539fb..0d630c40 100644 --- a/testgen/mcp/server.py +++ b/testgen/mcp/server.py @@ -34,7 +34,10 @@ ALWAYS look them up using either the `testgen://test-types` resource or the `get_test_type()` tool. Hygiene issue types similarly have specific meanings. ALWAYS look them up using the -`testgen://hygiene-issue-types` resource. +`testgen://hygiene-issue-types` resource.q + +Column profile fields are type-specific (different stats per Alpha / Numeric / Date / Boolean / Other). +ALWAYS look them up using the `testgen://column-profile-fields` resource. INVESTIGATING FAILURES @@ -149,6 +152,7 @@ def build_mcp_server( update_hygiene_issue, ) from testgen.mcp.tools.profiling import ( + get_column_profile_detail, get_profiling_run, get_table, list_column_profiles, @@ -156,6 +160,7 @@ def build_mcp_server( list_profiling_summaries, ) from testgen.mcp.tools.reference import ( + column_profile_fields_resource, get_test_type, glossary_resource, hygiene_issue_types_resource, @@ -222,6 +227,7 @@ def safe_prompt(fn): safe_tool(list_profiling_summaries) safe_tool(list_profiling_runs) safe_tool(get_profiling_run) + safe_tool(get_column_profile_detail) safe_tool(run_tests) safe_tool(run_profiling) safe_tool(cancel_test_run) @@ -235,6 +241,7 @@ def safe_prompt(fn): # Resources safe_resource("testgen://test-types", test_types_resource) safe_resource("testgen://hygiene-issue-types", hygiene_issue_types_resource) + safe_resource("testgen://column-profile-fields", column_profile_fields_resource) safe_resource("testgen://glossary", glossary_resource) # Prompts diff --git a/testgen/mcp/tools/common.py b/testgen/mcp/tools/common.py index 4ddb39a2..8b445723 100644 --- a/testgen/mcp/tools/common.py +++ b/testgen/mcp/tools/common.py @@ -6,6 +6,7 @@ from testgen.common.enums import ImpactDimension, QualityDimension from testgen.common.models.hygiene_issue import Disposition, HygieneIssueType, IssueLikelihood, PiiRisk from testgen.common.models.job_execution import JobStatus +from testgen.common.models.profiling_run import ProfilingRun from testgen.common.models.scheduler import JobSchedule from testgen.common.models.table_group import TableGroup from testgen.common.models.test_definition import TestType @@ -250,3 +251,17 @@ def resolve_test_suite(test_suite_id: str) -> TestSuite: if suite is None: raise MCPResourceNotAccessible("Test suite", test_suite_id) return suite + + +def resolve_profiling_run(job_execution_id: str) -> ProfilingRun: + """Resolve a profiling run by id-or-JE-id, scoped to allowed projects. + + Collapses missing-or-inaccessible into a single ``MCPResourceNotAccessible`` + so callers don't leak existence of runs they shouldn't see. + """ + run_uuid = parse_uuid(job_execution_id, "job_execution_id") + run = ProfilingRun.get_by_id_or_job(run_uuid) + perms = get_project_permissions() + if run is None or not perms.has_access(run.project_code): + raise MCPResourceNotAccessible("Profiling run", job_execution_id) + return run diff --git a/testgen/mcp/tools/markdown.py b/testgen/mcp/tools/markdown.py index ceac0ded..23e3c6d7 100644 --- a/testgen/mcp/tools/markdown.py +++ b/testgen/mcp/tools/markdown.py @@ -44,7 +44,6 @@ def _format_dt(value: object) -> str | None: return value[:16].replace("T", " ") + " UTC" return None - def _format_part(value: object) -> str: """Format a single value for text() parts — datetime-aware, no escaping.""" if value is None: @@ -52,6 +51,11 @@ def _format_part(value: object) -> str: return dt_str if (dt_str := _format_dt(value)) else str(value) +def _format_boolean(value: object) -> str | None: + if isinstance(value, bool): + return "Yes" if value else "No" + return None + # --------------------------------------------------------------------------- # MdDoc # --------------------------------------------------------------------------- @@ -204,6 +208,8 @@ def _format_field_value(value: object, *, code: bool = False) -> str: return "\u2014" if dt_str := _format_dt(value): return MdDoc.code(dt_str) if code else dt_str + if bool_str := _format_boolean(value): + return MdDoc.code(bool_str) if code else bool_str s = str(value) return MdDoc.code(s) if code else s diff --git a/testgen/mcp/tools/profiling.py b/testgen/mcp/tools/profiling.py index 1d8bbb1e..aea98447 100644 --- a/testgen/mcp/tools/profiling.py +++ b/testgen/mcp/tools/profiling.py @@ -1,12 +1,14 @@ +import dataclasses from uuid import UUID from testgen.common.models import with_database_session -from testgen.common.models.data_column import ColumnProfileSummary, DataColumnChars +from testgen.common.models.data_column import ColumnProfileDetail, ColumnProfileSummary, DataColumnChars from testgen.common.models.data_table import DataTable from testgen.common.models.job_execution import JobExecution from testgen.common.models.profiling_run import ProfilingRun, ProfilingRunSummary from testgen.common.models.scheduler import RUN_PROFILE_JOB_KEY from testgen.common.models.table_group import TableGroup, TableGroupSummary +from testgen.common.pii_masking import mask_profiling_pii from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError from testgen.mcp.permissions import get_project_permissions, mcp_permission from testgen.mcp.tools.common import ( @@ -17,6 +19,7 @@ next_scheduled_run, parse_run_status_filter, parse_uuid, + resolve_profiling_run, resolve_table_group, validate_limit, validate_page, @@ -97,9 +100,8 @@ def list_column_profiles( profiling_run_id: UUID | None = None if job_execution_id: - run_uuid = parse_uuid(job_execution_id, "job_execution_id") - profiling_run = ProfilingRun.get_by_id_or_job(run_uuid) - if profiling_run is None or profiling_run.table_groups_id != tg.id: + profiling_run = resolve_profiling_run(job_execution_id) + if profiling_run.table_groups_id != tg.id: raise MCPResourceNotAccessible("Profiling run", job_execution_id) profiling_run_id = profiling_run.id @@ -216,9 +218,9 @@ def list_profiling_summaries( def _format_pii(value: str | None) -> str | None: """Render a `pii_flag` value as a human label. Mirrors `PiiDisplay` in metadata_tags.js.""" if not value: - return None + return "No" if value == "MANUAL": - return "PII" + return "Yes" risk, _, rest = value.partition("/") type_code, _, detail = rest.partition("/") risk_label = _PII_RISK_MAP.get(risk, "Moderate") @@ -228,7 +230,7 @@ def _format_pii(value: str | None) -> str | None: caption += f" - {type_label}" if detail and detail != type_label: caption += f" / {detail}" - return f"PII ({caption})" + return f"Yes ({caption})" def _render_column_profile_row(c: ColumnProfileSummary) -> list: @@ -459,3 +461,252 @@ def _render_table_group_summary(doc: MdDoc, s: TableGroupSummary) -> None: doc.field("Profiling Run", s.latest_profile_job_execution_id, code=True) if s.monitor_lookback_end: doc.field("Last monitored", s.monitor_lookback_end) + + +# --------------------------------------------------------------------------- +# get_column_profile_detail +# --------------------------------------------------------------------------- + +# Friendly labels for `std_pattern_match` — mirrors `standardPatternLabels` in +# `ui/components/frontend/js/data_profiling/column_distribution.js`. +_STD_PATTERN_LABELS = { + "STREET_ADDR": "Street Address", + "STATE_USA": "State (USA)", + "PHONE_USA": "Phone (USA)", + "EMAIL": "Email", + "ZIP_USA": "Zip Code (USA)", + "FILE_NAME": "Filename", + "CREDIT_CARD": "Credit Card", + "DELIMITED_DATA": "Delimited Data", + "SSN": "SSN (USA)", +} + + +def _format_std_pattern(value: str | None) -> str | None: + if not value: + return None + return _STD_PATTERN_LABELS.get(value, value.replace("_", " ").title()) + + +@with_database_session +@mcp_permission("catalog") +def get_column_profile_detail( + table_group_id: str, + table_name: str, + column_name: str, + job_execution_id: str | None = None, +) -> str: + """Get the type-specific value distribution and statistics for one column from its profiling run. + + Args: + table_group_id: UUID of the table group, e.g. from `get_data_inventory`. + table_name: Table name exactly as stored in TestGen (case-sensitive). + column_name: Column name exactly as stored in TestGen (case-sensitive). + job_execution_id: UUID of a profiling run, e.g. from `list_profiling_summaries`. + When omitted, uses the column's latest complete run. + """ + tg = resolve_table_group(table_group_id) + + profiling_run_id: UUID | None = None + if job_execution_id: + profiling_run = resolve_profiling_run(job_execution_id) + if profiling_run.table_groups_id != tg.id: + raise MCPResourceNotAccessible("Profiling run", job_execution_id) + profiling_run_id = profiling_run.id + + detail = DataColumnChars.get_column_detail( + table_groups_id=tg.id, + table_name=table_name, + column_name=column_name, + profiling_run_id=profiling_run_id, + ) + if detail is None: + raise MCPResourceNotAccessible("Column", column_name) + + if detail.profile_run_id is None: + if job_execution_id: + raise MCPUserError( + f"Profiling run `{job_execution_id}` did not include column `{column_name}`." + ) + raise MCPUserError( + f"Column `{column_name}` has not been profiled yet. " + "Run profiling for the table group first." + ) + + if detail.profile_run_status in ("Running", "Error", "Cancelled"): + _raise_run_not_ready(detail) + + perms = get_project_permissions() + payload = dataclasses.asdict(detail) + if tg.project_code not in perms.codes_allowed_to("view_pii") and detail.pii_flag: + mask_profiling_pii(payload, {detail.column_name}) + + return _render_column_profile_detail(payload) + + +def _raise_run_not_ready(detail: ColumnProfileDetail) -> None: + """Reject when the resolved profiling run is in `Running` or `Error` state. + + Surface the run id, status, started/ended timestamps, and `log_message` (Error only) + in the raised error so the LLM knows what to suggest next. + """ + je = detail.profile_run_je_id + status = detail.profile_run_status + started = detail.profile_run_started_at + ended = detail.profile_run_ended_at + started_label = started.strftime("%Y-%m-%d %H:%M UTC") if started else "—" + ended_label = ended.strftime("%Y-%m-%d %H:%M UTC") if ended else "—" + lines = [ + f"Profiling run `{je}` is in `{status}` state — no profile detail available.", + f"Started: {started_label}. Ended: {ended_label}.", + ] + if status == "Error" and detail.profile_run_log_message: + lines.append(f"Error: {detail.profile_run_log_message}") + raise MCPUserError("\n".join(lines)) + + +def _render_column_profile_detail(p: dict) -> str: + """Render a column profile detail payload as grouped Markdown sections.""" + doc = MdDoc() + fq_name = f"{p['schema_name']}.{p['table_name']}" if p["schema_name"] else p["table_name"] + doc.heading(1, f"Column Profile: `{p['column_name']}` in `{fq_name}`") + + general_type = p.get("general_type") + + # Run identity + L1 header fields + doc.field("Profiling Run", p["profile_run_je_id"], code=True) + doc.field("Profiled at", p["profile_run_started_at"]) + doc.field("General Type", _format_general_type(general_type)) + doc.field("Data Type", p["db_data_type"]) + doc.field("Semantic Data Type", p["functional_data_type"]) + if p.get("datatype_suggestion"): + doc.field("Suggested Data Type", p["datatype_suggestion"]) + doc.field("PII", _format_pii(p.get("pii_flag"))) + doc.field("Critical Data Element", p.get("critical_data_element") or False) + doc.field("Profiling Score", friendly_score(p.get("dq_score_profiling"))) + doc.field("Testing Score", friendly_score(p.get("dq_score_testing"))) + + if not p.get("query_error"): + doc.field("Hygiene Issues (confirmed)", p.get("hygiene_issue_count", 0)) + + # Type-specific dispatch (T and unknown fall through to common-counts only) + if general_type == "A": + _render_alpha_block(doc, p) + elif general_type == "N": + _render_numeric_block(doc, p) + elif general_type == "D": + _render_date_block(doc, p) + elif general_type == "B": + _render_boolean_block(doc, p) + else: + _render_unknown_block(doc, p) + else: + doc.heading(2, "Profiling Error") + doc.text(p["query_error"]) + + return doc.render() + + +_FIELD_GENERAL_TYPE_LABELS = { + "A": "Alpha", + "B": "Boolean", + "D": "Date", + "N": "Numeric", + "T": "Time", + "X": "Other", +} + + +def _format_general_type(value: str) -> str: + return _FIELD_GENERAL_TYPE_LABELS.get(value or "X") + + +def _render_counts(doc: MdDoc, p: dict) -> None: + doc.heading(2, "Counts") + doc.field("Row Count", p.get("record_ct")) + doc.field("Value Count", p.get("value_ct")) + doc.field("Distinct Values", p.get("distinct_value_ct")) + doc.field("Null", p.get("null_value_ct")) + doc.field("Dummy Values", p.get("filled_value_ct")) + doc.field("Zero Values", p.get("zero_value_ct")) + + +def _render_alpha_block(doc: MdDoc, p: dict) -> None: + _render_counts(doc, p) + doc.field("Zero Length", p.get("zero_length_ct")) + + doc.heading(2, "Length") + doc.field("Minimum Length", p.get("min_length")) + doc.field("Maximum Length", p.get("max_length")) + doc.field("Average Length", p.get("avg_length")) + + doc.heading(2, "Text Range") + doc.field("Minimum Text", p.get("min_text")) + doc.field("Maximum Text", p.get("max_text")) + + doc.heading(2, "Patterns") + doc.field("Standard Pattern Match", _format_std_pattern(p.get("std_pattern_match"))) + doc.field("Distinct Patterns", p.get("distinct_pattern_ct")) + doc.field("Frequent Patterns", p.get("top_patterns")) + doc.field("Frequent Values", p.get("top_freq_values")) + doc.field("Distinct Standard Values", p.get("distinct_std_value_ct")) + + doc.heading(2, "Case & Composition") + doc.field("Upper Case", p.get("upper_case_ct")) + doc.field("Lower Case", p.get("lower_case_ct")) + doc.field("Mixed Case", p.get("mixed_case_ct")) + doc.field("Non-Alpha", p.get("non_alpha_ct")) + doc.field("Includes Digits", p.get("includes_digit_ct")) + doc.field("Numeric Values", p.get("numeric_ct")) + doc.field("Date Values", p.get("date_ct")) + doc.field("Quoted Values", p.get("quoted_value_ct")) + doc.field("Leading Spaces", p.get("lead_space_ct")) + doc.field("Embedded Spaces", p.get("embedded_space_ct")) + doc.field("Average Embedded Spaces", p.get("avg_embedded_spaces")) + + +def _render_numeric_block(doc: MdDoc, p: dict) -> None: + _render_counts(doc, p) + + doc.heading(2, "Distribution") + doc.field("Minimum Value", p.get("min_value")) + doc.field("Minimum Value > 0", p.get("min_value_over_0")) + doc.field("Maximum Value", p.get("max_value")) + doc.field("Average Value", p.get("avg_value")) + doc.field("Standard Deviation", p.get("stdev_value")) + + doc.heading(2, "Percentiles") + doc.field("25th Percentile", p.get("percentile_25")) + doc.field("Median Value", p.get("percentile_50")) + doc.field("75th Percentile", p.get("percentile_75")) + + +def _render_date_block(doc: MdDoc, p: dict) -> None: + _render_counts(doc, p) + + doc.heading(2, "Date Range") + doc.field("Minimum Date", p.get("min_date")) + doc.field("Maximum Date", p.get("max_date")) + + doc.heading(2, "Age Buckets") + doc.field("Before 1 Year", p.get("before_1yr_date_ct")) + doc.field("Before 5 Years", p.get("before_5yr_date_ct")) + doc.field("Before 20 Years", p.get("before_20yr_date_ct")) + doc.field("Within 1 Year", p.get("within_1yr_date_ct")) + doc.field("Within 1 Month", p.get("within_1mo_date_ct")) + doc.field("Future Dates", p.get("future_date_ct")) + + +def _render_boolean_block(doc: MdDoc, p: dict) -> None: + _render_counts(doc, p) + + doc.heading(2, "Boolean Distribution") + true_ct = p.get("boolean_true_ct") or 0 + value_ct = p.get("value_ct") or 0 + false_ct = max(value_ct - true_ct, 0) + doc.field("True", true_ct) + doc.field("False", false_ct) + + +def _render_unknown_block(doc: MdDoc, p: dict) -> None: + _render_counts(doc, p) diff --git a/testgen/mcp/tools/reference.py b/testgen/mcp/tools/reference.py index abaa4d17..d9aa75bc 100644 --- a/testgen/mcp/tools/reference.py +++ b/testgen/mcp/tools/reference.py @@ -105,6 +105,131 @@ def hygiene_issue_types_resource() -> str: return doc.render() +def column_profile_fields_resource() -> str: + """Reference for column-profile fields by general_type, with PII redaction notes.""" + return """\ +# TestGen Column Profile Fields Reference + +Column profiling stores ~70 statistics per column. The fields populated +depend on the column's `General Type` (Alpha / Numeric / Date / Boolean / Other). The +`get_column_profile_detail` tool emits only the fields relevant to a column's type — use +this reference to interpret what each field measures. + +## All Column Types + +These fields are populated for every successfully-profiled column. + +### Header +- **Profiling Run** — `job_execution_id` of the profiling run the rest of the fields come from. +- **Profiled at** — Timestamp when the profiling run started (`YYYY-MM-DD HH:MM UTC`). +- **General Type** — Broad category: `Alpha`, `Numeric`, `Date`, `Boolean`, `Time`, or `Other`. +- **Data Type** — Native DB type as reported by the source (e.g. `varchar(50)`, `numeric(18,4)`). +- **Semantic Data Type** — TestGen's functional classification (e.g. `Person Given Name`, `Currency`, `Datetime-Created`). +- **Suggested Data Type** — Suggested narrower DB type given observed values (e.g. `VARCHAR(20)`, `INTEGER`). Omitted when no suggestion applies. +- **PII** — `No` when the column has no PII flag; `Yes` when manually flagged; otherwise `Yes ( Risk[ - ][ / ])` — Risk is `High`, `Moderate`, or `Low`; Category is `ID`, `Name`, `Demographic`, or `Contact`; Detail is a subtype (e.g. `Email`, `Passport`) when present. +- **Critical Data Element** — `Yes` if the column is flagged as critical (directly or via its parent table), `No` otherwise. +- **Profiling Score** — Aggregated profiling-derived quality score, 0-100. +- **Testing Score** — Aggregated testing-derived quality score, 0-100. +- **Hygiene Issues (confirmed)** — Confirmed hygiene issues against this column (count). Omitted when the column has a profiling error. + +### Counts +- **Row Count** — Total rows in the table (count, integer). +- **Value Count** — Non-null values in this column (count, integer). +- **Distinct Values** — Distinct non-null values (count, integer). +- **Null** — Null values (count, integer). +- **Dummy Values** — Dummy / placeholder values like `'?'`, `'-'`, `'unknown'` (count, integer). +- **Zero Values** — Exact-zero or `'0'`-string values (count, integer). Populated for numeric and alpha columns. + +## Alpha (text) Columns + +Populated when `General Type == "Alpha"`. + +### Length +- **Minimum Length** — Shortest string length (chars). +- **Maximum Length** — Longest string length (chars). +- **Average Length** — Average string length (chars, float). + +### Text Range +- **Minimum Text** — Lexicographic minimum value (raw string; **PII-redactable**). +- **Maximum Text** — Lexicographic maximum value (raw string; **PII-redactable**). + +### Patterns +- **Standard Pattern Match** — Recognized standard pattern when applicable (`Email`, `Phone (USA)`, + `Street Address`, `State (USA)`, `Zip Code (USA)`, `Filename`, `Credit Card`, `Delimited Data`, `SSN (USA)`). +- **Distinct Patterns** — Distinct character-class patterns observed (count). +- **Frequent Patterns** — Top patterns and counts, pipe-separated. +- **Frequent Values** — Top frequent raw values and counts (raw strings; **PII-redactable**). +- **Distinct Standard Values** — Distinct values after standardization (count). + +### Case & Composition +- **Upper Case / Lower Case / Mixed Case / Non-Alpha** — Case-distribution counts. +- **Includes Digits** — Values containing at least one digit (count). +- **Numeric Values** — Values parseable as numeric (count). +- **Date Values** — Values parseable as a date (count). +- **Quoted Values** — Values wrapped in quotes (count). +- **Leading Spaces** — Values with leading whitespace (count). +- **Embedded Spaces** — Values with internal whitespace (count). +- **Average Embedded Spaces** — Average embedded-space count per value (float). +- **Zero Length** — Empty strings (count). + +## Numeric Columns + +Populated when `General Type == "Numeric"`. + +### Distribution +- **Minimum Value** — Minimum numeric value (raw value; **PII-redactable**). +- **Minimum Value > 0** — Minimum value strictly greater than zero (**PII-redactable**). +- **Maximum Value** — Maximum numeric value (**PII-redactable**). +- **Average Value** — Arithmetic mean. +- **Standard Deviation** — Standard deviation. + +### Percentiles +- **25th Percentile** — 25th percentile (Q1). +- **Median Value** — Median (Q2 / 50th percentile). +- **75th Percentile** — 75th percentile (Q3). + +## Date Columns + +Populated when `General Type == "Date"`. + +### Date Range +- **Minimum Date** — Minimum timestamp (**PII-redactable**). +- **Maximum Date** — Maximum timestamp (**PII-redactable**). + +### Age Buckets +- **Before 1 Year** — Values older than 1 year from profiling date (count). +- **Before 5 Years** — Values older than 5 years (count). +- **Before 20 Years** — Values older than 20 years (count). +- **Within 1 Year** — Values within the past year (count). +- **Within 1 Month** — Values within the past month (count). +- **Future Dates** — Values dated after the profiling date (count). + +## Boolean Columns + +Populated when `General Type == "Boolean"`. + +- **True** — Rows where the value is true (count). +- **False** — Rows where the value is false (count, derived as `Value Count - True`). + +## PII Redaction + +When a column is flagged as PII AND the caller's role lacks permission to view PII on the column's +project, the following raw-value fields render as `[PII Redacted]`: + +- Frequent Values +- Minimum Text +- Maximum Text +- Minimum Value +- Minimum Value > 0 +- Maximum Value +- Minimum Date +- Maximum Date + +Aggregates, counts, `Frequent Patterns`, and `Standard Pattern Match` are never redacted — they're +distribution-level signals that don't expose individual rows. +""" + + def glossary_resource() -> str: """Glossary of TestGen concepts, entity hierarchy, result statuses, and quality dimensions.""" return """\ diff --git a/tests/unit/mcp/test_model_data_column.py b/tests/unit/mcp/test_model_data_column.py new file mode 100644 index 00000000..102c811f --- /dev/null +++ b/tests/unit/mcp/test_model_data_column.py @@ -0,0 +1,240 @@ +from datetime import datetime +from unittest.mock import patch +from uuid import uuid4 + +from testgen.common.models.data_column import ColumnProfileDetail, DataColumnChars + + +def _detail_row(**overrides) -> dict: + """Build a dict matching every ColumnProfileDetail field.""" + base = { + # Identity + "column_name": "customer_name", + "table_name": "customers", + "schema_name": "demo", + # Types & metadata + "general_type": "A", + "column_type": "varchar(50)", + "db_data_type": "varchar(50)", + "functional_data_type": "Person Given Name", + "datatype_suggestion": "VARCHAR(20)", + "functional_table_type": None, + "pii_flag": "B/NAME/Individual", + "critical_data_element": False, + # Counts + "record_ct": 500, + "value_ct": 500, + "distinct_value_ct": 260, + "null_value_ct": 0, + "filled_value_ct": 0, + "zero_value_ct": 0, + # Alpha + "min_length": 3, + "max_length": 50, + "avg_length": 12.4, + "min_text": "Aaron", + "max_text": "Zoey", + "top_freq_values": "| Mary | 12\n| John | 10", + "top_patterns": "10 | A(5) | 8 | A(6)", + "distinct_std_value_ct": 250, + "distinct_pattern_ct": 35, + "std_pattern_match": None, + "mixed_case_ct": 100, + "lower_case_ct": 350, + "upper_case_ct": 50, + "non_alpha_ct": 0, + "includes_digit_ct": 0, + "numeric_ct": 0, + "date_ct": 0, + "quoted_value_ct": 0, + "lead_space_ct": 0, + "embedded_space_ct": 0, + "avg_embedded_spaces": 0.0, + "zero_length_ct": 0, + # Numeric (None for an alpha column) + "min_value": None, + "min_value_over_0": None, + "max_value": None, + "avg_value": None, + "stdev_value": None, + "percentile_25": None, + "percentile_50": None, + "percentile_75": None, + # Date + "min_date": None, + "max_date": None, + "before_1yr_date_ct": None, + "before_5yr_date_ct": None, + "before_20yr_date_ct": None, + "within_1yr_date_ct": None, + "within_1mo_date_ct": None, + "future_date_ct": None, + # Boolean + "boolean_true_ct": None, + # Per-column profiling failure + "query_error": None, + # Scores & hygiene + "dq_score_profiling": 100.0, + "dq_score_testing": 98.5, + "hygiene_issue_count": 1, + # Run identity + "profile_run_id": uuid4(), + "profile_run_je_id": uuid4(), + "profile_run_status": "Complete", + "profile_run_started_at": datetime(2026, 5, 1, 12, 0, 0), + "profile_run_ended_at": datetime(2026, 5, 1, 12, 5, 0), + "profile_run_log_message": None, + } + base.update(overrides) + return base + + +@patch("testgen.common.models.data_column.get_current_session") +def test_get_column_detail_returns_dataclass_when_row_exists(session_mock): + row = _detail_row() + session_mock.return_value.execute.return_value.mappings.return_value.first.return_value = row + + result = DataColumnChars.get_column_detail( + table_groups_id=uuid4(), + table_name="customers", + column_name="customer_name", + ) + + assert isinstance(result, ColumnProfileDetail) + assert result.column_name == "customer_name" + assert result.general_type == "A" + assert result.min_text == "Aaron" + assert result.profile_run_status == "Complete" + assert result.hygiene_issue_count == 1 + + +@patch("testgen.common.models.data_column.get_current_session") +def test_get_column_detail_returns_none_when_missing(session_mock): + session_mock.return_value.execute.return_value.mappings.return_value.first.return_value = None + + result = DataColumnChars.get_column_detail( + table_groups_id=uuid4(), + table_name="customers", + column_name="ghost_column", + ) + + assert result is None + + +@patch("testgen.common.models.data_column.get_current_session") +def test_get_column_detail_numeric_column_carries_numeric_fields(session_mock): + row = _detail_row( + column_name="amount", + general_type="N", + column_type="numeric(18,4)", + db_data_type="numeric", + functional_data_type="Currency", + pii_flag=None, + # Numeric stats populated; alpha fields naturally None at the DB level for numeric columns + min_value=0.0, + min_value_over_0=0.01, + max_value=99999.99, + avg_value=125.34, + stdev_value=42.1, + percentile_25=50.0, + percentile_50=100.0, + percentile_75=200.0, + # Alpha fields cleared for realism + min_text=None, + max_text=None, + top_freq_values=None, + top_patterns=None, + ) + session_mock.return_value.execute.return_value.mappings.return_value.first.return_value = row + + result = DataColumnChars.get_column_detail( + table_groups_id=uuid4(), table_name="orders", column_name="amount" + ) + + assert result.general_type == "N" + assert result.min_value == 0.0 + assert result.percentile_50 == 100.0 + assert result.min_text is None + + +@patch("testgen.common.models.data_column.get_current_session") +def test_get_column_detail_date_column_carries_date_fields(session_mock): + row = _detail_row( + column_name="created_at", + general_type="D", + functional_data_type="Datetime-Created", + min_date=datetime(2024, 1, 1, 0, 0, 0), + max_date=datetime(2026, 4, 30, 23, 59, 59), + before_1yr_date_ct=10000, + before_5yr_date_ct=2000, + before_20yr_date_ct=0, + within_1yr_date_ct=40000, + within_1mo_date_ct=5000, + future_date_ct=0, + ) + session_mock.return_value.execute.return_value.mappings.return_value.first.return_value = row + + result = DataColumnChars.get_column_detail( + table_groups_id=uuid4(), table_name="orders", column_name="created_at" + ) + + assert result.general_type == "D" + assert result.min_date == datetime(2024, 1, 1, 0, 0, 0) + assert result.within_1yr_date_ct == 40000 + + +@patch("testgen.common.models.data_column.get_current_session") +def test_get_column_detail_boolean_column_carries_true_count(session_mock): + row = _detail_row( + column_name="is_active", + general_type="B", + functional_data_type="Boolean", + boolean_true_ct=420, + value_ct=500, + ) + session_mock.return_value.execute.return_value.mappings.return_value.first.return_value = row + + result = DataColumnChars.get_column_detail( + table_groups_id=uuid4(), table_name="users", column_name="is_active" + ) + + assert result.general_type == "B" + assert result.boolean_true_ct == 420 + + +@patch("testgen.common.models.data_column.get_current_session") +def test_get_column_detail_pinned_profiling_run_id_appears_in_query(session_mock): + """When profiling_run_id is supplied, the rendered query references that pinned id.""" + pinned_id = uuid4() + session_mock.return_value.execute.return_value.mappings.return_value.first.return_value = None + + DataColumnChars.get_column_detail( + table_groups_id=uuid4(), + table_name="customers", + column_name="customer_name", + profiling_run_id=pinned_id, + ) + + # The query passed to execute() should reference the pinned id literally. + call_args = session_mock.return_value.execute.call_args + query = call_args[0][0] + sql_str = str(query.compile(compile_kwargs={"literal_binds": True})) + # SQLAlchemy renders UUID literal binds without dashes (.hex form). + assert pinned_id.hex in sql_str or str(pinned_id) in sql_str + + +@patch("testgen.common.models.data_column.get_current_session") +def test_get_column_detail_no_pin_uses_last_complete_profile_run_id(session_mock): + """Without a pin, the join should reference the column's last_complete_profile_run_id column.""" + session_mock.return_value.execute.return_value.mappings.return_value.first.return_value = None + + DataColumnChars.get_column_detail( + table_groups_id=uuid4(), + table_name="customers", + column_name="customer_name", + ) + + call_args = session_mock.return_value.execute.call_args + query = call_args[0][0] + sql_str = str(query.compile(compile_kwargs={"literal_binds": True})) + assert "last_complete_profile_run_id" in sql_str diff --git a/tests/unit/mcp/test_tools_common.py b/tests/unit/mcp/test_tools_common.py index 0e5e1b03..4a9eb432 100644 --- a/tests/unit/mcp/test_tools_common.py +++ b/tests/unit/mcp/test_tools_common.py @@ -1,12 +1,12 @@ from unittest.mock import MagicMock, patch -from uuid import UUID +from uuid import UUID, uuid4 import pytest from testgen.common.enums import ImpactDimension, QualityDimension from testgen.common.models.hygiene_issue import Disposition, IssueLikelihood, PiiRisk from testgen.common.models.test_result import TestResultStatus -from testgen.mcp.exceptions import MCPUserError +from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError from testgen.mcp.tools.common import ( format_disposition, parse_disposition, @@ -17,6 +17,7 @@ parse_result_status, parse_uuid, resolve_issue_type, + resolve_profiling_run, validate_limit, validate_page, ) @@ -279,3 +280,53 @@ def test_resolve_issue_type_not_found_raises_with_resource_hint(): with pytest.raises(MCPUserError, match="Unknown hygiene issue type") as exc_info: resolve_issue_type("Made-Up Type") assert "testgen://hygiene-issue-types" in str(exc_info.value) + + +# --- resolve_profiling_run --- + + +def _mock_perms(allowed_projects=("demo",)): + perms = MagicMock() + perms.has_access.side_effect = lambda code: code in allowed_projects + return perms + + +@patch("testgen.mcp.tools.common.get_project_permissions") +@patch("testgen.mcp.tools.common.ProfilingRun") +def test_resolve_profiling_run_happy_path(mock_pr_cls, mock_get_perms, db_session_mock): + run = MagicMock() + run.project_code = "demo" + mock_pr_cls.get_by_id_or_job.return_value = run + mock_get_perms.return_value = _mock_perms(allowed_projects=("demo",)) + + result = resolve_profiling_run(str(uuid4())) + + assert result is run + + +@patch("testgen.mcp.tools.common.get_project_permissions") +@patch("testgen.mcp.tools.common.ProfilingRun") +def test_resolve_profiling_run_unknown_run_id(mock_pr_cls, mock_get_perms, db_session_mock): + mock_pr_cls.get_by_id_or_job.return_value = None + mock_get_perms.return_value = _mock_perms() + + with pytest.raises(MCPResourceNotAccessible, match=r"Profiling run .* not found or not accessible"): + resolve_profiling_run(str(uuid4())) + + +@patch("testgen.mcp.tools.common.get_project_permissions") +@patch("testgen.mcp.tools.common.ProfilingRun") +def test_resolve_profiling_run_inaccessible_project(mock_pr_cls, mock_get_perms, db_session_mock): + """Run exists but caller can't access its project — same unified error as unknown run.""" + run = MagicMock() + run.project_code = "forbidden" + mock_pr_cls.get_by_id_or_job.return_value = run + mock_get_perms.return_value = _mock_perms(allowed_projects=("demo",)) + + with pytest.raises(MCPResourceNotAccessible, match=r"Profiling run .* not found or not accessible"): + resolve_profiling_run(str(uuid4())) + + +def test_resolve_profiling_run_invalid_uuid(): + with pytest.raises(MCPUserError, match="Invalid job_execution_id"): + resolve_profiling_run("not-a-uuid") diff --git a/tests/unit/mcp/test_tools_profiling.py b/tests/unit/mcp/test_tools_profiling.py index 9415933b..5a4ec01f 100644 --- a/tests/unit/mcp/test_tools_profiling.py +++ b/tests/unit/mcp/test_tools_profiling.py @@ -1,9 +1,11 @@ +from datetime import datetime from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest -from testgen.common.models.data_column import ColumnProfileSummary +from testgen.common.models.data_column import ColumnProfileDetail, ColumnProfileSummary +from testgen.common.pii_masking import PII_REDACTED from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError from testgen.mcp.permissions import ProjectPermissions @@ -252,7 +254,7 @@ def test_list_column_profiles_paginates(mock_tg_cls, mock_dcc_cls, db_session_mo assert "Use `page=2` for more" in result -@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.common.ProfilingRun") @patch("testgen.mcp.tools.profiling.DataColumnChars") @patch("testgen.mcp.tools.common.TableGroup") def test_list_column_profiles_with_valid_job_execution_id( @@ -262,6 +264,7 @@ def test_list_column_profiles_with_valid_job_execution_id( pr = MagicMock() pr.id = uuid4() pr.table_groups_id = tg.id + pr.project_code = tg.project_code mock_tg_cls.get.return_value = tg mock_pr_cls.get_by_id_or_job.return_value = pr @@ -273,7 +276,7 @@ def test_list_column_profiles_with_valid_job_execution_id( assert mock_dcc_cls.list_for_table_group.call_args.kwargs["profiling_run_id"] == pr.id -@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.common.ProfilingRun") @patch("testgen.mcp.tools.common.TableGroup") def test_list_column_profiles_rejects_je_from_different_tg( mock_tg_cls, mock_pr_cls, db_session_mock, @@ -283,6 +286,7 @@ def test_list_column_profiles_rejects_je_from_different_tg( pr = MagicMock() pr.id = uuid4() pr.table_groups_id = uuid4() # different TG + pr.project_code = tg.project_code mock_tg_cls.get.return_value = tg mock_pr_cls.get_by_id_or_job.return_value = pr @@ -292,7 +296,7 @@ def test_list_column_profiles_rejects_je_from_different_tg( list_column_profiles(str(uuid4()), job_execution_id=str(uuid4())) -@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.common.ProfilingRun") @patch("testgen.mcp.tools.common.TableGroup") def test_list_column_profiles_rejects_unknown_je(mock_tg_cls, mock_pr_cls, db_session_mock): mock_tg_cls.get.return_value = _mock_table_group() @@ -336,14 +340,14 @@ def test_list_column_profiles_inaccessible_tg(mock_tg_cls, db_session_mock): @pytest.mark.parametrize( "value,expected", [ - (None, None), - ("", None), - ("MANUAL", "PII"), - ("A/ID/Passport", "PII (High Risk - ID / Passport)"), - ("B/NAME/Individual", "PII (Moderate Risk - Name / Individual)"), - ("C/CONTACT", "PII (Low Risk - Contact)"), - ("B/ID/ID", "PII (Moderate Risk - ID)"), # detail collapses when equal to type label - ("X/UNKNOWN/Detail", "PII (Moderate Risk / Detail)"), # unknown risk falls back; unknown type drops label + (None, "No"), + ("", "No"), + ("MANUAL", "Yes"), + ("A/ID/Passport", "Yes (High Risk - ID / Passport)"), + ("B/NAME/Individual", "Yes (Moderate Risk - Name / Individual)"), + ("C/CONTACT", "Yes (Low Risk - Contact)"), + ("B/ID/ID", "Yes (Moderate Risk - ID)"), # detail collapses when equal to type label + ("X/UNKNOWN/Detail", "Yes (Moderate Risk / Detail)"), # unknown risk falls back; unknown type drops label ], ) def test_format_pii(value, expected): @@ -359,12 +363,12 @@ def test_format_pii(value, expected): def test_render_row_renders_parsed_pii_label(): from testgen.mcp.tools.profiling import _render_column_profile_row row = _render_column_profile_row(_column_summary(pii_flag="B/NAME/Individual")) - assert row[5] == "PII (Moderate Risk - Name / Individual)" + assert row[5] == "Yes (Moderate Risk - Name / Individual)" -def test_render_row_falsy_pii_renders_none(): +def test_render_row_falsy_pii_renders_no(): from testgen.mcp.tools.profiling import _render_column_profile_row - assert _render_column_profile_row(_column_summary(pii_flag=None))[5] is None + assert _render_column_profile_row(_column_summary(pii_flag=None))[5] == "No" def test_render_row_cde_collapsed_to_y_or_none(): @@ -485,7 +489,7 @@ def test_list_profiling_summaries_inaccessible_tg(mock_tg_cls, db_session_mock): # list_profiling_runs # ---------------------------------------------------------------------- -from datetime import UTC, datetime +from datetime import UTC from testgen.common.models.job_execution import JobStatus @@ -674,3 +678,636 @@ def test_get_profiling_run_invalid_uuid(db_session_mock): from testgen.mcp.tools.profiling import get_profiling_run with pytest.raises(MCPUserError, match="not a valid UUID"): get_profiling_run("not-a-uuid") + + +# ---------------------------------------------------------------------- +# get_column_profile_detail +# ---------------------------------------------------------------------- + + +def _column_detail(**overrides) -> ColumnProfileDetail: + """Build a ColumnProfileDetail with sensible alpha-column defaults; override per test.""" + base: dict = { + # Identity + "column_name": "customer_name", + "table_name": "customers", + "schema_name": "demo", + # Types & metadata + "general_type": "A", + "column_type": "varchar(50)", + "db_data_type": "varchar(50)", + "functional_data_type": "Person Given Name", + "datatype_suggestion": "VARCHAR(20)", + "functional_table_type": None, + "pii_flag": None, + "critical_data_element": False, + # Counts + "record_ct": 500, + "value_ct": 500, + "distinct_value_ct": 260, + "null_value_ct": 0, + "filled_value_ct": 0, + "zero_value_ct": 0, + # Alpha + "min_length": 3, + "max_length": 50, + "avg_length": 12.4, + "min_text": "Aaron", + "max_text": "Zoey", + "top_freq_values": "| Mary | 12\n| John | 10", + "top_patterns": "10 | A(5) | 8 | A(6)", + "distinct_std_value_ct": 250, + "distinct_pattern_ct": 35, + "std_pattern_match": None, + "mixed_case_ct": 100, + "lower_case_ct": 350, + "upper_case_ct": 50, + "non_alpha_ct": 0, + "includes_digit_ct": 0, + "numeric_ct": 0, + "date_ct": 0, + "quoted_value_ct": 0, + "lead_space_ct": 0, + "embedded_space_ct": 0, + "avg_embedded_spaces": 0.0, + "zero_length_ct": 0, + # Numeric + "min_value": None, + "min_value_over_0": None, + "max_value": None, + "avg_value": None, + "stdev_value": None, + "percentile_25": None, + "percentile_50": None, + "percentile_75": None, + # Date + "min_date": None, + "max_date": None, + "before_1yr_date_ct": None, + "before_5yr_date_ct": None, + "before_20yr_date_ct": None, + "within_1yr_date_ct": None, + "within_1mo_date_ct": None, + "future_date_ct": None, + # Boolean + "boolean_true_ct": None, + # Per-column profiling failure + "query_error": None, + # Scores & hygiene + "dq_score_profiling": 95.2, + "dq_score_testing": 90.0, + "hygiene_issue_count": 2, + # Run identity + "profile_run_id": uuid4(), + "profile_run_je_id": uuid4(), + "profile_run_status": "Complete", + "profile_run_started_at": datetime(2026, 5, 1, 12, 0, 0), + "profile_run_ended_at": datetime(2026, 5, 1, 12, 5, 0), + "profile_run_log_message": None, + } + base.update(overrides) + return ColumnProfileDetail(**base) + + +# --- happy paths per general_type --- + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_alpha_renders_alpha_sections(mock_tg_cls, mock_dcc_cls, db_session_mock): + mock_tg_cls.get.return_value = _mock_table_group() + mock_dcc_cls.get_column_detail.return_value = _column_detail(general_type="A") + + from testgen.mcp.tools.profiling import get_column_profile_detail + result = get_column_profile_detail(str(uuid4()), "customers", "customer_name") + + assert "Column Profile" in result + assert "customer_name" in result + assert "Profiling Run" in result + # Alpha-specific sections present + assert "Length" in result + assert "Text Range" in result + assert "Patterns" in result + assert "Aaron" in result + assert "Zoey" in result + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_alpha_renders_distinct_standard_values( + mock_tg_cls, mock_dcc_cls, db_session_mock +): + """`distinct_std_value_ct` (alpha-only) renders under the Patterns section as 'Distinct Standard Values'.""" + mock_tg_cls.get.return_value = _mock_table_group() + mock_dcc_cls.get_column_detail.return_value = _column_detail( + general_type="A", + distinct_std_value_ct=247, + ) + + from testgen.mcp.tools.profiling import get_column_profile_detail + result = get_column_profile_detail(str(uuid4()), "customers", "customer_name") + + assert "Distinct Standard Values" in result + assert "247" in result + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_numeric_renders_numeric_sections(mock_tg_cls, mock_dcc_cls, db_session_mock): + mock_tg_cls.get.return_value = _mock_table_group() + mock_dcc_cls.get_column_detail.return_value = _column_detail( + column_name="amount", + general_type="N", + db_data_type="numeric", + functional_data_type="Currency", + # Numeric stats + min_value=0.0, + min_value_over_0=0.01, + max_value=99999.99, + avg_value=125.34, + stdev_value=42.1, + percentile_25=50.0, + percentile_50=100.0, + percentile_75=200.0, + # Alpha fields cleared (numeric column wouldn't have these populated) + min_text=None, + max_text=None, + top_freq_values=None, + top_patterns=None, + min_length=None, + max_length=None, + avg_length=None, + std_pattern_match=None, + ) + + from testgen.mcp.tools.profiling import get_column_profile_detail + result = get_column_profile_detail(str(uuid4()), "orders", "amount") + + # Numeric-specific content present + assert "Median" in result or "Percentile" in result or "percentile_50" in result.lower() + assert "99999.99" in result or "99,999.99" in result + # Alpha-only sections absent + assert "Text Range" not in result + assert "Min Text" not in result + assert "Aaron" not in result + assert "Length" not in result.replace("Avg Length", "") # rough — ensures no Length section + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_date_renders_date_sections(mock_tg_cls, mock_dcc_cls, db_session_mock): + mock_tg_cls.get.return_value = _mock_table_group() + mock_dcc_cls.get_column_detail.return_value = _column_detail( + column_name="created_at", + general_type="D", + db_data_type="timestamp", + functional_data_type="Datetime-Created", + min_date=datetime(2024, 1, 1, 0, 0, 0), + max_date=datetime(2026, 4, 30, 23, 59, 59), + before_1yr_date_ct=10000, + before_5yr_date_ct=2000, + before_20yr_date_ct=0, + within_1yr_date_ct=40000, + within_1mo_date_ct=5000, + future_date_ct=0, + # Alpha fields cleared + min_text=None, + max_text=None, + top_freq_values=None, + top_patterns=None, + min_length=None, + max_length=None, + avg_length=None, + ) + + from testgen.mcp.tools.profiling import get_column_profile_detail + result = get_column_profile_detail(str(uuid4()), "orders", "created_at") + + # Date-specific content + assert "Within 1" in result or "Before 1" in result or "Date Range" in result + assert "2024" in result + # Alpha-only sections absent + assert "Aaron" not in result + assert "Pattern" not in result + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_boolean_renders_boolean_section(mock_tg_cls, mock_dcc_cls, db_session_mock): + mock_tg_cls.get.return_value = _mock_table_group() + mock_dcc_cls.get_column_detail.return_value = _column_detail( + column_name="is_active", + general_type="B", + db_data_type="boolean", + functional_data_type="Boolean", + boolean_true_ct=420, + value_ct=500, + # Alpha fields cleared + min_text=None, + max_text=None, + top_freq_values=None, + top_patterns=None, + min_length=None, + max_length=None, + avg_length=None, + ) + + from testgen.mcp.tools.profiling import get_column_profile_detail + result = get_column_profile_detail(str(uuid4()), "users", "is_active") + + assert "True" in result + assert "420" in result + # Alpha-only sections absent + assert "Pattern" not in result + assert "Length" not in result.replace("Avg Length", "") + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_unknown_general_type_renders_counts_only( + mock_tg_cls, mock_dcc_cls, db_session_mock +): + mock_tg_cls.get.return_value = _mock_table_group() + mock_dcc_cls.get_column_detail.return_value = _column_detail( + column_name="payload", + general_type="X", + db_data_type="json", + functional_data_type=None, + # All type-specific fields cleared + min_text=None, + max_text=None, + top_freq_values=None, + top_patterns=None, + min_length=None, + max_length=None, + avg_length=None, + ) + + from testgen.mcp.tools.profiling import get_column_profile_detail + result = get_column_profile_detail(str(uuid4()), "events", "payload") + + assert "payload" in result + assert "Counts" in result + assert "Pattern" not in result + assert "Boolean Distribution" not in result + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_general_type_t_treated_as_unknown( + mock_tg_cls, mock_dcc_cls, db_session_mock +): + """T mirrors current UI behavior — falls through to common counts only.""" + mock_tg_cls.get.return_value = _mock_table_group() + mock_dcc_cls.get_column_detail.return_value = _column_detail( + column_name="ts", + general_type="T", + db_data_type="time", + functional_data_type=None, + min_text=None, + max_text=None, + top_freq_values=None, + top_patterns=None, + min_length=None, + max_length=None, + avg_length=None, + ) + + from testgen.mcp.tools.profiling import get_column_profile_detail + result = get_column_profile_detail(str(uuid4()), "events", "ts") + + assert "Counts" in result + assert "Date Range" not in result # not dispatched as date + + +# --- never-profiled / no-profile-for-pinned-run --- + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_never_profiled_column_rejects( + mock_tg_cls, mock_dcc_cls, db_session_mock +): + """Column row exists in data_column_chars but has no completed profiling run yet + (`last_complete_profile_run_id IS NULL`). The model returns a detail with NULL run + fields; the tool must reject rather than render an empty profile. + """ + mock_tg_cls.get.return_value = _mock_table_group() + mock_dcc_cls.get_column_detail.return_value = _column_detail( + profile_run_id=None, + profile_run_je_id=None, + profile_run_status=None, + profile_run_started_at=None, + profile_run_ended_at=None, + ) + + from testgen.mcp.tools.profiling import get_column_profile_detail + with pytest.raises(MCPUserError) as exc_info: + get_column_profile_detail(str(uuid4()), "customers", "customer_name") + + msg = str(exc_info.value) + assert "customer_name" in msg + assert "not been profiled" in msg + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.ProfilingRun") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_pinned_run_without_column_rejects( + mock_tg_cls, mock_pr_cls, mock_dcc_cls, db_session_mock, +): + """User pins a valid run via job_execution_id, but that run has no profile_results + row for this column. Surface the pinned run id so the LLM knows what to try next. + """ + tg = _mock_table_group() + pr = MagicMock() + pr.id = uuid4() + pr.table_groups_id = tg.id + pr.project_code = tg.project_code + + mock_tg_cls.get.return_value = tg + mock_pr_cls.get_by_id_or_job.return_value = pr + mock_dcc_cls.get_column_detail.return_value = _column_detail( + profile_run_id=None, + profile_run_je_id=None, + profile_run_status=None, + profile_run_started_at=None, + profile_run_ended_at=None, + ) + + je_id_str = str(uuid4()) + from testgen.mcp.tools.profiling import get_column_profile_detail + with pytest.raises(MCPUserError) as exc_info: + get_column_profile_detail( + str(uuid4()), "customers", "customer_name", job_execution_id=je_id_str + ) + + msg = str(exc_info.value) + assert "customer_name" in msg + assert je_id_str in msg + + +# --- error paths --- + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_column_not_found_unified_error( + mock_tg_cls, mock_dcc_cls, db_session_mock +): + mock_tg_cls.get.return_value = _mock_table_group() + mock_dcc_cls.get_column_detail.return_value = None + + from testgen.mcp.tools.profiling import get_column_profile_detail + with pytest.raises(MCPResourceNotAccessible, match=r"Column .* not found or not accessible"): + get_column_profile_detail(str(uuid4()), "customers", "ghost_column") + + +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_inaccessible_tg(mock_tg_cls, db_session_mock): + mock_tg_cls.get.return_value = None + + from testgen.mcp.tools.profiling import get_column_profile_detail + with pytest.raises(MCPResourceNotAccessible, match=r"Table group .* not found or not accessible"): + get_column_profile_detail(str(uuid4()), "customers", "x") + + +def test_get_column_profile_detail_invalid_tg_uuid(db_session_mock): + from testgen.mcp.tools.profiling import get_column_profile_detail + with pytest.raises(MCPUserError, match="Invalid table_group_id"): + get_column_profile_detail("not-a-uuid", "customers", "x") + + +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_invalid_je_uuid(mock_tg_cls, db_session_mock): + mock_tg_cls.get.return_value = _mock_table_group() + + from testgen.mcp.tools.profiling import get_column_profile_detail + with pytest.raises(MCPUserError, match="Invalid job_execution_id"): + get_column_profile_detail( + str(uuid4()), "customers", "x", job_execution_id="bad" + ) + + +# --- job_execution_id pinning --- + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.ProfilingRun") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_pinned_run_passes_id_to_model( + mock_tg_cls, mock_pr_cls, mock_dcc_cls, db_session_mock, +): + tg = _mock_table_group() + pr = MagicMock() + pr.id = uuid4() + pr.table_groups_id = tg.id + pr.project_code = tg.project_code + + mock_tg_cls.get.return_value = tg + mock_pr_cls.get_by_id_or_job.return_value = pr + mock_dcc_cls.get_column_detail.return_value = _column_detail() + + from testgen.mcp.tools.profiling import get_column_profile_detail + get_column_profile_detail(str(uuid4()), "customers", "customer_name", job_execution_id=str(uuid4())) + + assert mock_dcc_cls.get_column_detail.call_args.kwargs["profiling_run_id"] == pr.id + + +@patch("testgen.mcp.tools.common.ProfilingRun") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_pinned_run_from_different_tg_unified_error( + mock_tg_cls, mock_pr_cls, db_session_mock, +): + tg = _mock_table_group() + pr = MagicMock() + pr.id = uuid4() + pr.table_groups_id = uuid4() # different + pr.project_code = tg.project_code + + mock_tg_cls.get.return_value = tg + mock_pr_cls.get_by_id_or_job.return_value = pr + + from testgen.mcp.tools.profiling import get_column_profile_detail + with pytest.raises(MCPResourceNotAccessible, match=r"Profiling run .* not found or not accessible"): + get_column_profile_detail( + str(uuid4()), "customers", "x", job_execution_id=str(uuid4()) + ) + + +@patch("testgen.mcp.tools.common.ProfilingRun") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_pinned_run_unknown_unified_error( + mock_tg_cls, mock_pr_cls, db_session_mock, +): + mock_tg_cls.get.return_value = _mock_table_group() + mock_pr_cls.get_by_id_or_job.return_value = None + + from testgen.mcp.tools.profiling import get_column_profile_detail + with pytest.raises(MCPResourceNotAccessible, match=r"Profiling run .* not found or not accessible"): + get_column_profile_detail( + str(uuid4()), "customers", "x", job_execution_id=str(uuid4()) + ) + + +# --- run-status preconditions --- + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_running_run_rejects_with_status( + mock_tg_cls, mock_dcc_cls, db_session_mock +): + mock_tg_cls.get.return_value = _mock_table_group() + je_id = uuid4() + mock_dcc_cls.get_column_detail.return_value = _column_detail( + profile_run_status="Running", + profile_run_je_id=je_id, + profile_run_ended_at=None, + ) + + from testgen.mcp.tools.profiling import get_column_profile_detail + with pytest.raises(MCPUserError) as exc_info: + get_column_profile_detail(str(uuid4()), "customers", "customer_name") + + msg = str(exc_info.value) + assert "Running" in msg + assert str(je_id) in msg + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_error_run_includes_log_message( + mock_tg_cls, mock_dcc_cls, db_session_mock +): + mock_tg_cls.get.return_value = _mock_table_group() + je_id = uuid4() + mock_dcc_cls.get_column_detail.return_value = _column_detail( + profile_run_status="Error", + profile_run_je_id=je_id, + profile_run_log_message="connection timed out", + ) + + from testgen.mcp.tools.profiling import get_column_profile_detail + with pytest.raises(MCPUserError) as exc_info: + get_column_profile_detail(str(uuid4()), "customers", "customer_name") + + msg = str(exc_info.value) + assert "Error" in msg + assert str(je_id) in msg + assert "connection timed out" in msg + + +# --- PII redaction --- + + +@patch("testgen.mcp.permissions._compute_project_permissions") +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_pii_column_no_view_pii_redacts( + mock_tg_cls, mock_dcc_cls, mock_compute, db_session_mock, +): + """User has 'catalog' on demo but NOT 'view_pii' → 8 raw-value fields redacted; aggregates kept.""" + mock_tg_cls.get.return_value = _mock_table_group() + mock_dcc_cls.get_column_detail.return_value = _column_detail( + pii_flag="B/CONTACT/Email", + column_name="customer_email", + general_type="A", + std_pattern_match="EMAIL", + min_text="aaron@example.com", + max_text="zoey@example.com", + top_freq_values="| mary@x.com | 1\n| john@x.com | 1", + ) + # No project includes view_pii — only catalog allowed + mock_compute.return_value = ProjectPermissions( + memberships={"demo": "role_c"}, # role_c has 'catalog' but not 'view_pii' in test matrix + permission="catalog", + username="test_user", + ) + + from testgen.mcp.tools.profiling import get_column_profile_detail + result = get_column_profile_detail(str(uuid4()), "customers", "customer_email") + + # Raw-value fields redacted + assert PII_REDACTED in result + assert "aaron@example.com" not in result + assert "zoey@example.com" not in result + assert "mary@x.com" not in result + # Aggregates / counts / std_pattern_match still visible + assert "260" in result or "Distinct" in result + assert "EMAIL" in result or "Email" in result + + +@patch("testgen.mcp.permissions._compute_project_permissions") +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_pii_column_with_view_pii_shows_values( + mock_tg_cls, mock_dcc_cls, mock_compute, db_session_mock, +): + mock_tg_cls.get.return_value = _mock_table_group() + mock_dcc_cls.get_column_detail.return_value = _column_detail( + pii_flag="B/CONTACT/Email", + column_name="customer_email", + min_text="aaron@example.com", + max_text="zoey@example.com", + ) + mock_compute.return_value = ProjectPermissions( + memberships={"demo": "role_a"}, # role_a has 'view_pii' in conftest matrix? actually no — but we need a role that includes view_pii. Use role-with-view_pii via "edit" mapping. + permission="catalog", + username="test_user", + ) + # Patch the rbac mapping so role_a includes view_pii for this test + with patch("testgen.mcp.permissions.PluginHook") as mock_hook: + mock_hook.instance.return_value.rbac.get_roles_with_permission.side_effect = ( + lambda perm: ["role_a"] if perm in ("catalog", "view_pii") else [] + ) + from testgen.mcp.tools.profiling import get_column_profile_detail + result = get_column_profile_detail(str(uuid4()), "customers", "customer_email") + + assert "aaron@example.com" in result + assert PII_REDACTED not in result + + +@patch("testgen.mcp.permissions._compute_project_permissions") +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_non_pii_column_never_redacts( + mock_tg_cls, mock_dcc_cls, mock_compute, db_session_mock, +): + """No pii_flag → raw values shown regardless of view_pii grant.""" + mock_tg_cls.get.return_value = _mock_table_group() + mock_dcc_cls.get_column_detail.return_value = _column_detail( + pii_flag=None, + min_text="Aaron", + max_text="Zoey", + ) + mock_compute.return_value = ProjectPermissions( + memberships={"demo": "role_c"}, + permission="catalog", + username="test_user", + ) + + from testgen.mcp.tools.profiling import get_column_profile_detail + result = get_column_profile_detail(str(uuid4()), "customers", "customer_name") + + assert "Aaron" in result + assert "Zoey" in result + assert PII_REDACTED not in result + + +# --- query_error surfacing --- + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_profile_detail_query_error_section(mock_tg_cls, mock_dcc_cls, db_session_mock): + mock_tg_cls.get.return_value = _mock_table_group() + mock_dcc_cls.get_column_detail.return_value = _column_detail( + query_error="ORA-01017: invalid username/password", + ) + + from testgen.mcp.tools.profiling import get_column_profile_detail + result = get_column_profile_detail(str(uuid4()), "customers", "customer_name") + + assert "Profiling Error" in result + assert "ORA-01017" in result diff --git a/tests/unit/mcp/test_tools_reference.py b/tests/unit/mcp/test_tools_reference.py index 55f6b508..96ead51f 100644 --- a/tests/unit/mcp/test_tools_reference.py +++ b/tests/unit/mcp/test_tools_reference.py @@ -176,3 +176,90 @@ def test_hygiene_issue_types_resource_empty(mock_type_cls, db_session_mock): result = hygiene_issue_types_resource() assert "No hygiene issue types found" in result + + +# --- column_profile_fields_resource --- + + +def test_column_profile_fields_resource_has_five_sections(): + from testgen.mcp.tools.reference import column_profile_fields_resource + + result = column_profile_fields_resource() + + assert "TestGen Column Profile Fields Reference" in result + assert "## All Column Types" in result + assert "## Alpha" in result + assert "## Numeric" in result + assert "## Date" in result + assert "## Boolean" in result + + +def test_column_profile_fields_resource_lists_all_pii_redacted_fields(): + """The footer must name every redactable field so the LLM can interpret `[PII Redacted]` markers.""" + from testgen.mcp.tools.reference import column_profile_fields_resource + + result = column_profile_fields_resource() + + # Friendly labels mirroring PROFILING_PII_FIELDS from testgen.common.pii_masking. + expected_labels = ( + "Frequent Values", + "Minimum Text", + "Maximum Text", + "Minimum Value", + "Minimum Value > 0", + "Maximum Value", + "Minimum Date", + "Maximum Date", + ) + for label in expected_labels: + assert label in result, f"Expected `{label}` to be named in the redaction note" + + +def test_column_profile_fields_resource_describes_redaction_trigger(): + from testgen.mcp.tools.reference import column_profile_fields_resource + + result = column_profile_fields_resource() + + # The redaction trigger: column is PII-flagged AND caller lacks permission to view PII. + assert "PII" in result + assert "permission to view PII" in result + + +def test_column_profile_fields_resource_describes_per_type_fields(): + """Each section should at least mention the most distinctive field for that type.""" + from testgen.mcp.tools.reference import column_profile_fields_resource + + result = column_profile_fields_resource() + + # All-types section + assert "Row Count" in result + assert "Hygiene Issues" in result + # Alpha + assert "Minimum Length" in result + assert "Frequent Values" in result + assert "Standard Pattern Match" in result + # Numeric + assert "Minimum Value" in result + assert "Median Value" in result + # Datetime + assert "Minimum Date" in result + assert "Before 1 Year" in result + # Boolean + assert "## Boolean Columns" in result + + +# --- server instructions reference the new resource --- + + +def test_server_instructions_reference_column_profile_fields_resource(): + """The LLM relies on SERVER_INSTRUCTIONS to learn which resources to consult. + + The new resource must be named alongside test-types and hygiene-issue-types so + the LLM knows when to look up column-profile field semantics. + """ + from testgen.mcp.server import SERVER_INSTRUCTIONS + + assert "testgen://column-profile-fields" in SERVER_INSTRUCTIONS + # Sanity check the existing references are still present. + assert "testgen://test-types" in SERVER_INSTRUCTIONS + assert "testgen://hygiene-issue-types" in SERVER_INSTRUCTIONS From e02afc35b9bbf07f100084ffd1e7dc06186d4b38 Mon Sep 17 00:00:00 2001 From: Ricardo Boni Date: Tue, 12 May 2026 10:34:42 -0400 Subject: [PATCH 07/58] refactor(mcp): apply TG-1054 review feedback Address review comments on MR !511: - create_test takes a single ``fields: dict`` instead of explicit kwargs; no field can bypass the editable_fields() whitelist. - editable_fields() gates column_name (column / custom scopes) and impact_dimension (custom / referential scopes) per UI logic. - Extract ``validate_custom_query`` to ``testgen/common/custom_test_validation.py``; used by both UI's validate_test and MCP's validate_custom_test. - validate_custom_test now wraps user SQL in ``SELECT COUNT(*) FROM (...) ERR_TABLE`` for the count and applies a flavor-aware ``LIMIT 1`` for the preview row. Matches the test runtime's wrapping pattern (correctness parity, blocks DDL/DML). - Output wording uses "rows matching the failure criteria" throughout; PII footer drops the ``view_pii`` jargon. - bulk_update_tests uses ``result.rowcount`` instead of materialising every UUID via ``.returning(id).all()``. - UI: validate_test calls the shared helper; impact_dimension gate simplifies to ``test_scope in ('custom', 'referential')``. Co-Authored-By: Claude Opus 4.7 (1M context) --- testgen/common/custom_test_validation.py | 75 +++++++ testgen/common/models/test_definition.py | 12 +- testgen/mcp/tools/test_definitions.py | 119 ++++------ .../frontend/js/pages/test_definitions.js | 2 +- testgen/ui/views/test_definitions.py | 16 +- .../common/models/test_test_definition.py | 34 +++ .../common/test_custom_test_validation.py | 194 ++++++++++++++++ tests/unit/mcp/test_tools_test_definitions.py | 207 +++++++++--------- 8 files changed, 467 insertions(+), 192 deletions(-) create mode 100644 testgen/common/custom_test_validation.py create mode 100644 tests/unit/common/test_custom_test_validation.py diff --git a/testgen/common/custom_test_validation.py b/testgen/common/custom_test_validation.py new file mode 100644 index 00000000..82de4754 --- /dev/null +++ b/testgen/common/custom_test_validation.py @@ -0,0 +1,75 @@ +"""Shared validation for custom-test SQL queries. + +Wraps user-supplied SQL in a parent ``SELECT COUNT(*) FROM () ERR_TABLE`` form +matching the test execution runtime, then runs it against the target database. Optional +preview returns the first N rows for inspection. + +Wrapping serves two purposes: +- Validation parity with runtime — a bare query that runs may still fail when wrapped. +- DDL/DML rejection — non-SELECT statements fail to parse as a subquery. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from sqlalchemy.engine import RowMapping + +from testgen.common.database.database_service import get_flavor_service, replace_params +from testgen.ui.services.database_service import fetch_from_target_db + +if TYPE_CHECKING: + from testgen.common.database.flavor.flavor_service import FlavorService + from testgen.common.models.connection import Connection + + +@dataclass +class CustomQueryResult: + """Outcome of running a wrapped custom-test SQL query.""" + + row_count: int + preview_rows: list[RowMapping] = field(default_factory=list) + + +def validate_custom_query( + connection: Connection, + schema: str, + custom_sql: str, + preview_limit: int = 0, +) -> CustomQueryResult: + """Wrap and execute a custom-test SQL query against the target DB. + + Args: + connection: Target ``Connection`` to run the query on. + schema: Schema name for ``{DATA_SCHEMA}`` substitution in the user's SQL. + custom_sql: User-supplied query. Should return rows matching the test failure criteria. + preview_limit: When > 0, also fetch up to N rows for preview (only when row_count > 0). + + Returns the failure-criteria row count and (optionally) the preview rows. DB errors + propagate as-is — the caller decides how to surface them. + """ + sql_with_schema = replace_params(custom_sql, {"DATA_SCHEMA": schema}).rstrip().rstrip(";") + flavor_service = get_flavor_service(connection.sql_flavor) + + count_sql = f"SELECT COUNT(*) FROM ({sql_with_schema}) ERR_TABLE" + count_rows = fetch_from_target_db(connection, count_sql) + row_count = int(count_rows[0][0]) if count_rows else 0 + + preview_rows: list[RowMapping] = [] + if preview_limit > 0 and row_count > 0: + prefix, suffix = _row_limit_clauses(flavor_service, preview_limit) + preview_sql = f"SELECT {prefix} * FROM ({sql_with_schema}) ERR_TABLE {suffix}".strip() + preview_rows = fetch_from_target_db(connection, preview_sql) + + return CustomQueryResult(row_count=row_count, preview_rows=preview_rows) + + +def _row_limit_clauses(flavor_service: FlavorService, n: int) -> tuple[str, str]: + """Return (prefix, suffix) for limiting a SELECT to N rows on the given flavor.""" + clause = flavor_service.row_limiting_clause + if clause == "top": + return f"TOP {n}", "" + if clause == "fetch": + return "", f"FETCH FIRST {n} ROWS ONLY" + return "", f"LIMIT {n}" diff --git a/testgen/common/models/test_definition.py b/testgen/common/models/test_definition.py index 50cd78ef..f10a6b31 100644 --- a/testgen/common/models/test_definition.py +++ b/testgen/common/models/test_definition.py @@ -445,7 +445,17 @@ def list_for_suite( def editable_fields(self, test_type: TestType) -> set[str]: """Fields a caller may set or change on this test definition under the given test type.""" - return self.EDITABLE_BASE_FIELDS | test_type.param_columns + fields = self.EDITABLE_BASE_FIELDS | test_type.param_columns + # column_name is meaningful for column-scoped tests (the column under test) and + # custom-scoped tests (a "Test Focus" label). Other scopes don't use it. + if test_type.test_scope in ("column", "custom"): + fields = fields | {"column_name"} + # impact_dimension is overridable only for user-defined-semantic scopes + # (custom-scope = user-authored SQL; referential-scope = comparison-based tests). + # Other scopes have baked-in dimensions so the override doesn't apply. + if test_type.test_scope in ("custom", "referential"): + fields = fields | {"impact_dimension"} + return fields def validate(self, test_type: TestType) -> None: """Validate the current state against the given test type. diff --git a/testgen/mcp/tools/test_definitions.py b/testgen/mcp/tools/test_definitions.py index f764999a..d95d27c4 100644 --- a/testgen/mcp/tools/test_definitions.py +++ b/testgen/mcp/tools/test_definitions.py @@ -4,6 +4,7 @@ from sqlalchemy import update +from testgen.common.custom_test_validation import validate_custom_query from testgen.common.enums import ImpactDimension, QualityDimension from testgen.common.models import get_current_session, with_database_session from testgen.common.models.connection import Connection @@ -32,7 +33,6 @@ validate_page, ) from testgen.mcp.tools.markdown import MdDoc -from testgen.ui.services.database_service import fetch_from_target_db _DOC_GROUP = DocGroup.DISCOVER @@ -397,12 +397,7 @@ def create_test( test_suite_id: str, test_type: str, table_name: str, - column_name: str | None = None, - threshold_value: str | None = None, - baseline_value: str | None = None, - severity: str | None = None, - custom_query: str | None = None, - extra_params: dict | None = None, + fields: dict | None = None, ) -> str: """Create a test in a test suite. @@ -410,14 +405,10 @@ def create_test( test_suite_id: UUID of the test suite. test_type: Test type name, e.g. ``Alpha Truncation`` or ``Custom Test``. table_name: Target table name. Case-sensitive. - column_name: Required for column-scoped test types. - threshold_value: Test threshold. - baseline_value: Baseline reference. - severity: ``Fail`` or ``Warning``. Omit to inherit the test type default. - custom_query: SQL for tests that accept a custom query. - extra_params: Additional test-type-specific parameters (e.g. ``window_days``, - ``match_column_names``, ``lower_tolerance``). Use ``list_test_types`` or - ``get_test`` on a similar test to discover supported names. + fields: Mapping of field name to value for the test's parameters and metadata + (e.g. ``threshold_value``, ``custom_query``, ``severity``, ``column_name``, + ``test_description``). Use ``list_test_types`` or ``get_test`` on a similar + test to discover what's settable for the chosen test type. """ suite = resolve_test_suite(test_suite_id) tt_code = resolve_test_type(test_type) @@ -439,33 +430,15 @@ def create_test( lock_refresh=False, last_manual_update=datetime.now(UTC), ) - explicit = { - "column_name": column_name, - "threshold_value": threshold_value, - "baseline_value": baseline_value, - "severity": severity, - "custom_query": custom_query, - } - for key, value in explicit.items(): - if value is not None: - setattr(td, key, value) - - if extra_params: - accepted = td.editable_fields(tt) - rejected = sorted(set(extra_params) - accepted) - if rejected: - raise MCPUserError( - f"These `extra_params` keys are not editable for test type `{tt_code}`: " - f"{', '.join(rejected)}." - ) - conflicts = sorted(set(extra_params) & {k for k, v in explicit.items() if v is not None}) - if conflicts: - raise MCPUserError( - f"These fields were set both as named arguments and in `extra_params`: " - f"{', '.join(conflicts)}. Pass each value only once." - ) - for key, value in extra_params.items(): - setattr(td, key, value) + + fields = fields or {} + accepted = td.editable_fields(tt) + rejected = sorted(set(fields) - accepted) + if rejected: + bullets = "\n".join(f"- `{key}`: not editable for test type `{tt_code}`" for key in rejected) + raise MCPUserError(f"Test definition creation rejected. No changes saved.\n\n{bullets}") + for key, value in fields.items(): + setattr(td, key, value) try: td.validate(tt) @@ -492,9 +465,8 @@ def update_test(test_definition_id: str, fields: dict) -> str: Args: test_definition_id: UUID of the test definition. - fields: Mapping of field name to new value. Accepts the test type's parameter - columns (use ``get_test`` to see the current values and supported fields) - plus ``test_active``, ``severity``, ``lock_refresh``, ``flagged``. + fields: Mapping of field name to new value. Use ``get_test`` to see the current + values and which fields are settable for the test's type. """ td = resolve_test_definition(test_definition_id) tt = TestType.get(td.test_type) @@ -546,14 +518,20 @@ def _format_diff(value: object) -> str | None: def validate_custom_test(test_suite_id: str, custom_sql: str) -> str: """Dry-run a custom test SQL query against the test suite's parent connection. + The query should return rows matching the test failure criteria — returning no rows + means the test passes; returning any rows means it fails. + Args: test_suite_id: UUID of the test suite whose connection the SQL runs against. - custom_sql: SQL query to dry-run. + custom_sql: SQL query returning failure-criteria rows. """ suite = resolve_test_suite(test_suite_id) connection = Connection.get_by_table_group(suite.table_groups_id) if connection is None: raise MCPUserError("No connection configured for this test suite's table group.") + table_group = TableGroup.get(suite.table_groups_id) + if table_group is None: + raise MCPUserError("Test suite is not associated with a table group.") perms = get_project_permissions() can_view_pii = suite.project_code in perms.codes_allowed_to("view_pii") @@ -562,7 +540,9 @@ def validate_custom_test(test_suite_id: str, custom_sql: str) -> str: doc.heading(1, "Custom test dry-run") try: - rows = fetch_from_target_db(connection, custom_sql) + result = validate_custom_query( + connection, table_group.table_group_schema, custom_sql, preview_limit=1, + ) except Exception as e: # broad catch: the DB error message IS the user-facing signal doc.text(f"**SQL did not execute.** Query was not committed against `{connection.connection_name}`.") message = str(e.args[0]) if e.args else str(e) @@ -570,37 +550,38 @@ def validate_custom_test(test_suite_id: str, custom_sql: str) -> str: doc.code_block(message) return doc.render() - row_count = len(rows) flavor = connection.sql_flavor_code or connection.sql_flavor or "target database" doc.text( f"**SQL ran successfully** against `{connection.connection_name}` ({flavor})." ) - if row_count == 0: - doc.text("**Would pass:** ✓ — query returned 0 error rows.") + if result.row_count == 0: + doc.text("**Would pass:** ✓ — query returned 0 rows matching the failure criteria.") doc.text( "_If saved as a CUSTOM test, this would currently pass: the test fails when any " - "error rows are returned, and there are none._" + "rows match the failure criteria, and there are none._" ) return doc.render() - doc.text(f"**Would fail:** ✗ — query returned {row_count} error row(s).") - doc.heading(2, "Source data preview (first row)") - first = rows[0] - columns = list(first.keys()) - if can_view_pii: - values = [first[c] for c in columns] - else: - values = ["[redacted]"] * len(columns) - doc.table(columns, [values]) doc.text( - "_If saved as a CUSTOM test, this would currently fail because the SQL returned error " - "rows. Refine the query if some of those rows are false positives._" + f"**Would fail:** ✗ — query returned {result.row_count} row(s) matching the failure criteria." + ) + if result.preview_rows: + doc.heading(2, "Source data preview (first row)") + first = result.preview_rows[0] + columns = list(first.keys()) + if can_view_pii: + values = [first[c] for c in columns] + else: + values = ["[redacted]"] * len(columns) + doc.table(columns, [values]) + doc.text( + "_If saved as a CUSTOM test, this would currently fail because the SQL returned rows " + "matching the test failure criteria. Refine the query if some of those rows are false positives._" ) if not can_view_pii: doc.text( - "_PII redacted: caller does not have `view_pii` on this project. Column names shown " - "so the LLM can iterate on shape; row values are masked._" + "_PII redacted: caller does not have permissions to view PII on this project._" ) return doc.render() @@ -642,15 +623,9 @@ def bulk_update_tests( if tt_code: where_clauses.append(TestDefinition.test_type == tt_code) - stmt = ( - update(TestDefinition) - .where(*where_clauses) - .values(**values) - .returning(TestDefinition.id) - ) + stmt = update(TestDefinition).where(*where_clauses).values(**values) session = get_current_session() - affected = session.execute(stmt).all() - count = len(affected) + count = session.execute(stmt).rowcount verb = "Enabled" if target else "Disabled" filters = [] diff --git a/testgen/ui/components/frontend/js/pages/test_definitions.js b/testgen/ui/components/frontend/js/pages/test_definitions.js index d715178e..58c707f5 100644 --- a/testgen/ui/components/frontend/js/pages/test_definitions.js +++ b/testgen/ui/components/frontend/js/pages/test_definitions.js @@ -1071,7 +1071,7 @@ const TestDefFormContent = ({ formValues, tableColumns, testSuite, validateResul { label: 'Regularity', value: 'Regularity' }, { label: 'Usability', value: 'Usability' }, ]; - const showImpactDimensionOverride = testType === 'CUSTOM' || testType === 'Condition_Flag' || testScope === 'referential'; + const showImpactDimensionOverride = ['custom', 'referential'].includes(testScope); const tableNameOptions = [ ...new Set((tableColumns ?? []).map(c => c.table_name).filter(Boolean)) diff --git a/testgen/ui/views/test_definitions.py b/testgen/ui/views/test_definitions.py index e0d8d267..3fe28055 100644 --- a/testgen/ui/views/test_definitions.py +++ b/testgen/ui/views/test_definitions.py @@ -8,7 +8,8 @@ from sqlalchemy import and_, asc, case, desc, func, or_, tuple_ from testgen.common import date_service -from testgen.common.database.database_service import get_flavor_service, replace_params +from testgen.common.custom_test_validation import validate_custom_query +from testgen.common.database.database_service import get_flavor_service from testgen.common.models import with_database_session from testgen.common.models.connection import Connection from testgen.common.models.job_execution import JobExecution @@ -967,15 +968,6 @@ def validate_test(test_definition: dict, table_group: TableGroupMinimal) -> None ) FROM {quote}{schema}{quote}.{quote}{table_name}{quote}; """ + fetch_from_target_db(connection, query) else: - query = replace_params( - f""" - SELECT COUNT(*) - FROM ( - {test_definition["custom_query"]} - ) TEST - """, - {"DATA_SCHEMA": schema}, - ) - - fetch_from_target_db(connection, query) + validate_custom_query(connection, schema, test_definition["custom_query"]) diff --git a/tests/unit/common/models/test_test_definition.py b/tests/unit/common/models/test_test_definition.py index 929c7eee..8d9ffb12 100644 --- a/tests/unit/common/models/test_test_definition.py +++ b/tests/unit/common/models/test_test_definition.py @@ -96,6 +96,40 @@ def test_editable_fields_includes_param_columns(): assert {"threshold_value", "baseline_value"} <= accepted +def test_editable_fields_includes_impact_dimension_only_for_custom_or_referential_scope(): + """impact_dimension is overridable only for user-defined-semantic scopes.""" + td = make_td() + + custom_tt = make_test_type(scope="custom", param_columns={"custom_query"}) + assert "impact_dimension" in td.editable_fields(custom_tt) + + referential_tt = make_test_type(scope="referential", param_columns={"match_column_names"}) + assert "impact_dimension" in td.editable_fields(referential_tt) + + column_tt = make_test_type(scope="column", param_columns={"threshold_value"}) + assert "impact_dimension" not in td.editable_fields(column_tt) + + table_tt = make_test_type(scope="table", param_columns=set()) + assert "impact_dimension" not in td.editable_fields(table_tt) + + +def test_editable_fields_includes_column_name_only_for_column_or_custom_scope(): + """column_name is meaningful for column-scope (column under test) and custom-scope (label).""" + td = make_td() + + column_tt = make_test_type(scope="column", param_columns={"threshold_value"}) + assert "column_name" in td.editable_fields(column_tt) + + custom_tt = make_test_type(scope="custom", param_columns={"custom_query"}) + assert "column_name" in td.editable_fields(custom_tt) + + table_tt = make_test_type(scope="table", param_columns=set()) + assert "column_name" not in td.editable_fields(table_tt) + + referential_tt = make_test_type(scope="referential", param_columns={"match_column_names"}) + assert "column_name" not in td.editable_fields(referential_tt) + + def test_editable_fields_does_not_leak_identity_or_internal_columns(): tt = make_test_type(param_columns={"threshold_value"}) td = make_td() diff --git a/tests/unit/common/test_custom_test_validation.py b/tests/unit/common/test_custom_test_validation.py new file mode 100644 index 00000000..78302fec --- /dev/null +++ b/tests/unit/common/test_custom_test_validation.py @@ -0,0 +1,194 @@ +"""Tests for testgen.common.custom_test_validation.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from testgen.common.custom_test_validation import ( + CustomQueryResult, + _row_limit_clauses, + validate_custom_query, +) + + +def _flavor_service(row_limiting: str = "limit") -> MagicMock: + svc = MagicMock() + svc.row_limiting_clause = row_limiting + return svc + + +def _connection(flavor: str = "postgresql") -> MagicMock: + conn = MagicMock() + conn.sql_flavor = flavor + return conn + + +# -- _row_limit_clauses ------------------------------------------------------- + + +def test_row_limit_clauses_limit_flavor(): + prefix, suffix = _row_limit_clauses(_flavor_service("limit"), 5) + assert prefix == "" + assert suffix == "LIMIT 5" + + +def test_row_limit_clauses_top_flavor(): + prefix, suffix = _row_limit_clauses(_flavor_service("top"), 5) + assert prefix == "TOP 5" + assert suffix == "" + + +def test_row_limit_clauses_fetch_flavor(): + prefix, suffix = _row_limit_clauses(_flavor_service("fetch"), 5) + assert prefix == "" + assert suffix == "FETCH FIRST 5 ROWS ONLY" + + +# -- validate_custom_query ---------------------------------------------------- + + +@patch("testgen.common.custom_test_validation.fetch_from_target_db") +@patch("testgen.common.custom_test_validation.get_flavor_service") +def test_validate_custom_query_count_only(mock_get_flavor, mock_fetch): + mock_get_flavor.return_value = _flavor_service("limit") + mock_fetch.return_value = [(0,)] + + result = validate_custom_query( + _connection(), "demo", "SELECT * FROM orders WHERE total < 0", + ) + + assert isinstance(result, CustomQueryResult) + assert result.row_count == 0 + assert result.preview_rows == [] + # Only one fetch call: the count query + assert mock_fetch.call_count == 1 + # Verify the count query is wrapped with ERR_TABLE + count_sql = mock_fetch.call_args_list[0].args[1] + assert "SELECT COUNT(*)" in count_sql + assert "ERR_TABLE" in count_sql + + +@patch("testgen.common.custom_test_validation.fetch_from_target_db") +@patch("testgen.common.custom_test_validation.get_flavor_service") +def test_validate_custom_query_with_preview(mock_get_flavor, mock_fetch): + mock_get_flavor.return_value = _flavor_service("limit") + preview_row = MagicMock() + preview_row.keys.return_value = ["order_id", "amount"] + mock_fetch.side_effect = [ + [(3,)], # count query result + [preview_row], # preview query result + ] + + result = validate_custom_query( + _connection(), "demo", "SELECT * FROM orders WHERE total < 0", preview_limit=1, + ) + + assert result.row_count == 3 + assert result.preview_rows == [preview_row] + assert mock_fetch.call_count == 2 + preview_sql = mock_fetch.call_args_list[1].args[1] + assert "SELECT" in preview_sql + assert "ERR_TABLE" in preview_sql + assert "LIMIT 1" in preview_sql + + +@patch("testgen.common.custom_test_validation.fetch_from_target_db") +@patch("testgen.common.custom_test_validation.get_flavor_service") +def test_validate_custom_query_preview_skipped_when_no_rows(mock_get_flavor, mock_fetch): + mock_get_flavor.return_value = _flavor_service("limit") + mock_fetch.return_value = [(0,)] + + result = validate_custom_query( + _connection(), "demo", "SELECT 1 WHERE 1=0", preview_limit=5, + ) + + assert result.row_count == 0 + assert result.preview_rows == [] + # Preview query should NOT run when count is 0 + assert mock_fetch.call_count == 1 + + +@patch("testgen.common.custom_test_validation.fetch_from_target_db") +@patch("testgen.common.custom_test_validation.get_flavor_service") +def test_validate_custom_query_substitutes_data_schema(mock_get_flavor, mock_fetch): + mock_get_flavor.return_value = _flavor_service("limit") + mock_fetch.return_value = [(0,)] + + validate_custom_query( + _connection(), + "production_schema", + "SELECT * FROM {DATA_SCHEMA}.orders", + ) + + count_sql = mock_fetch.call_args_list[0].args[1] + # {DATA_SCHEMA} was substituted with the actual schema name + assert "production_schema.orders" in count_sql + assert "{DATA_SCHEMA}" not in count_sql + + +@patch("testgen.common.custom_test_validation.fetch_from_target_db") +@patch("testgen.common.custom_test_validation.get_flavor_service") +def test_validate_custom_query_strips_trailing_semicolon(mock_get_flavor, mock_fetch): + """Trailing semicolons break the subquery wrap — must be stripped.""" + mock_get_flavor.return_value = _flavor_service("limit") + mock_fetch.return_value = [(0,)] + + validate_custom_query( + _connection(), "demo", "SELECT 1; ", + ) + + count_sql = mock_fetch.call_args_list[0].args[1] + # The subquery should not contain a trailing semicolon + assert "SELECT 1)" in count_sql or "SELECT 1 )" in count_sql + # Specifically, the inner SELECT 1 should not be followed by ; inside the wrap + assert "SELECT 1;" not in count_sql + + +@patch("testgen.common.custom_test_validation.fetch_from_target_db") +@patch("testgen.common.custom_test_validation.get_flavor_service") +def test_validate_custom_query_uses_flavor_specific_limit(mock_get_flavor, mock_fetch): + """Oracle uses FETCH FIRST; MSSQL uses TOP — preview SQL must respect the flavor.""" + mock_get_flavor.return_value = _flavor_service("fetch") + preview_row = MagicMock() + mock_fetch.side_effect = [ + [(5,)], + [preview_row], + ] + + validate_custom_query( + _connection("oracle"), "demo", "SELECT * FROM t", preview_limit=1, + ) + + preview_sql = mock_fetch.call_args_list[1].args[1] + assert "FETCH FIRST 1 ROWS ONLY" in preview_sql + assert "LIMIT" not in preview_sql + + +@patch("testgen.common.custom_test_validation.fetch_from_target_db") +@patch("testgen.common.custom_test_validation.get_flavor_service") +def test_validate_custom_query_top_flavor_uses_prefix(mock_get_flavor, mock_fetch): + mock_get_flavor.return_value = _flavor_service("top") + preview_row = MagicMock() + mock_fetch.side_effect = [ + [(5,)], + [preview_row], + ] + + validate_custom_query( + _connection("mssql"), "demo", "SELECT * FROM t", preview_limit=1, + ) + + preview_sql = mock_fetch.call_args_list[1].args[1] + assert "TOP 1" in preview_sql + assert "LIMIT" not in preview_sql + + +@patch("testgen.common.custom_test_validation.fetch_from_target_db") +@patch("testgen.common.custom_test_validation.get_flavor_service") +def test_validate_custom_query_propagates_db_errors(mock_get_flavor, mock_fetch): + """DB errors propagate as-is — caller decides how to surface them.""" + mock_get_flavor.return_value = _flavor_service("limit") + mock_fetch.side_effect = Exception("syntax error at or near 'DROP'") + + with pytest.raises(Exception, match="syntax error"): + validate_custom_query(_connection(), "demo", "DROP TABLE orders") diff --git a/tests/unit/mcp/test_tools_test_definitions.py b/tests/unit/mcp/test_tools_test_definitions.py index e67e18a2..90213043 100644 --- a/tests/unit/mcp/test_tools_test_definitions.py +++ b/tests/unit/mcp/test_tools_test_definitions.py @@ -3,6 +3,7 @@ import pytest +from testgen.common.custom_test_validation import CustomQueryResult from testgen.mcp.exceptions import MCPUserError # -- list_tests --------------------------------------------------------------- @@ -640,6 +641,10 @@ def test_create_test_happy_path( saved = MagicMock() saved.id = uuid4() + saved.editable_fields.return_value = { + "test_active", "severity", "lock_refresh", "flagged", "test_description", + "threshold_value", "column_name", + } mock_td.return_value = saved mock_td.get_for_project.return_value = _make_td_summary() mock_notes.get_notes.return_value = [] @@ -650,9 +655,7 @@ def test_create_test_happy_path( test_suite_id=str(uuid4()), test_type="Alpha Truncation", table_name="orders", - column_name="email", - threshold_value="64", - severity="Warning", + fields={"column_name": "email", "threshold_value": "64", "severity": "Warning"}, ) # New shared body: entity-first heading + "Created in suite" lead-in @@ -665,18 +668,32 @@ def test_create_test_happy_path( saved.save.assert_called_once() +@patch("testgen.mcp.tools.test_definitions.TestDefinition") @patch("testgen.mcp.tools.test_definitions.TableGroup") @patch("testgen.mcp.tools.test_definitions.TestType") @patch("testgen.mcp.tools.test_definitions.resolve_test_type") @patch("testgen.mcp.tools.test_definitions.resolve_test_suite") def test_create_test_column_scope_requires_column_name( - mock_resolve_suite, mock_resolve_tt, mock_tt_model, mock_tg, db_session_mock + mock_resolve_suite, mock_resolve_tt, mock_tt_model, mock_tg, mock_td, db_session_mock, ): + """Column-scoped types: missing column_name → validate() raises before save.""" + from testgen.common.models.test_definition import InvalidTestDefinitionFields + mock_resolve_suite.return_value = _make_suite() mock_resolve_tt.return_value = "Alpha_Trunc" - mock_tt_model.get.return_value = _make_test_type() # column scope + mock_tt_model.get.return_value = _make_test_type() mock_tg.get.return_value = _make_table_group() + saved = MagicMock(id=uuid4()) + saved.editable_fields.return_value = { + "test_active", "severity", "lock_refresh", "flagged", "test_description", + "threshold_value", "column_name", + } + saved.validate.side_effect = InvalidTestDefinitionFields( + {"column_name": "required for test type `Alpha_Trunc`"} + ) + mock_td.return_value = saved + from testgen.mcp.tools.test_definitions import create_test with pytest.raises(MCPUserError) as exc_info: @@ -684,24 +701,34 @@ def test_create_test_column_scope_requires_column_name( test_suite_id=str(uuid4()), test_type="Alpha Truncation", table_name="orders", - threshold_value="64", + fields={"threshold_value": "64"}, ) assert "column_name" in str(exc_info.value) assert "rejected" in str(exc_info.value).lower() + saved.save.assert_not_called() +@patch("testgen.mcp.tools.test_definitions.TestDefinition") @patch("testgen.mcp.tools.test_definitions.TableGroup") @patch("testgen.mcp.tools.test_definitions.TestType") @patch("testgen.mcp.tools.test_definitions.resolve_test_type") @patch("testgen.mcp.tools.test_definitions.resolve_test_suite") -def test_create_test_custom_query_not_accepted_on_alpha_trunc( - mock_resolve_suite, mock_resolve_tt, mock_tt_model, mock_tg, db_session_mock +def test_create_test_unknown_field_rejected_by_whitelist( + mock_resolve_suite, mock_resolve_tt, mock_tt_model, mock_tg, mock_td, db_session_mock, ): + """Unknown field in ``fields`` (e.g. custom_query on Alpha_Trunc) is rejected by editable_fields whitelist.""" mock_resolve_suite.return_value = _make_suite() mock_resolve_tt.return_value = "Alpha_Trunc" - mock_tt_model.get.return_value = _make_test_type() # param_columns = {threshold_value} + mock_tt_model.get.return_value = _make_test_type() mock_tg.get.return_value = _make_table_group() + saved = MagicMock(id=uuid4()) + saved.editable_fields.return_value = { + "test_active", "severity", "lock_refresh", "flagged", "test_description", + "threshold_value", "column_name", + } + mock_td.return_value = saved + from testgen.mcp.tools.test_definitions import create_test with pytest.raises(MCPUserError) as exc_info: @@ -709,11 +736,11 @@ def test_create_test_custom_query_not_accepted_on_alpha_trunc( test_suite_id=str(uuid4()), test_type="Alpha Truncation", table_name="orders", - column_name="email", - threshold_value="64", - custom_query="SELECT 1", + fields={"column_name": "email", "threshold_value": "64", "custom_query": "SELECT 1"}, ) assert "custom_query" in str(exc_info.value) + assert "not editable" in str(exc_info.value) + saved.save.assert_not_called() @patch("testgen.mcp.tools.test_definitions.TestDefinitionNote") @@ -722,25 +749,25 @@ def test_create_test_custom_query_not_accepted_on_alpha_trunc( @patch("testgen.mcp.tools.test_definitions.TestType") @patch("testgen.mcp.tools.test_definitions.resolve_test_type") @patch("testgen.mcp.tools.test_definitions.resolve_test_suite") -def test_create_test_extra_params_pass_through( +def test_create_test_fields_dict_supports_test_type_params( mock_resolve_suite, mock_resolve_tt, mock_tt_model, mock_tg, mock_td, mock_notes, db_session_mock, ): - """extra_params adds fields not in the explicit kwargs (e.g. window_days).""" + """``fields`` accepts any param in editable_fields — e.g. window_days for a trend test.""" mock_resolve_suite.return_value = _make_suite() mock_resolve_tt.return_value = "Some_Trend" - # Test type accepts threshold_value AND window_days mock_tt_model.get.return_value = _make_test_type( code="Some_Trend", param_columns={"threshold_value", "window_days"}, default_parm_columns="threshold_value,window_days", ) mock_tg.get.return_value = _make_table_group() - saved_td = MagicMock(id=uuid4()) - saved_td.editable_fields.return_value = { + + saved = MagicMock(id=uuid4()) + saved.editable_fields.return_value = { "test_active", "severity", "lock_refresh", "flagged", "test_description", - "threshold_value", "window_days", + "threshold_value", "window_days", "column_name", } - mock_td.return_value = saved_td + mock_td.return_value = saved mock_td.get_for_project.return_value = _make_td_summary() mock_notes.get_notes.return_value = [] @@ -750,84 +777,42 @@ def test_create_test_extra_params_pass_through( test_suite_id=str(uuid4()), test_type="Some Trend", table_name="orders", - column_name="email", - threshold_value="10", - extra_params={"window_days": "7"}, + fields={"column_name": "amount", "threshold_value": "10", "window_days": "7"}, ) - # threshold_value (from kwarg) and window_days (from extras) were both setattr'd on the TD - assert saved_td.threshold_value == "10" - assert saved_td.window_days == "7" - saved_td.validate.assert_called_once() - saved_td.save.assert_called_once() - - -@patch("testgen.mcp.tools.test_definitions.TableGroup") -@patch("testgen.mcp.tools.test_definitions.TestType") -@patch("testgen.mcp.tools.test_definitions.resolve_test_type") -@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") -def test_create_test_extra_params_conflict_rejected( - mock_resolve_suite, mock_resolve_tt, mock_tt_model, mock_tg, db_session_mock, -): - """Passing the same field via both kwarg and extra_params is rejected.""" - mock_resolve_suite.return_value = _make_suite() - mock_resolve_tt.return_value = "Alpha_Trunc" - mock_tt_model.get.return_value = _make_test_type() - mock_tg.get.return_value = _make_table_group() - - from testgen.mcp.tools.test_definitions import create_test - - with pytest.raises(MCPUserError, match="both as named arguments and in"): - create_test( - test_suite_id=str(uuid4()), - test_type="Alpha Truncation", - table_name="orders", - column_name="email", - threshold_value="10", - extra_params={"threshold_value": "20"}, - ) - - -@patch("testgen.mcp.tools.test_definitions.TableGroup") -@patch("testgen.mcp.tools.test_definitions.TestType") -@patch("testgen.mcp.tools.test_definitions.resolve_test_type") -@patch("testgen.mcp.tools.test_definitions.resolve_test_suite") -def test_create_test_extra_params_unknown_field_rejected_via_validator( - mock_resolve_suite, mock_resolve_tt, mock_tt_model, mock_tg, db_session_mock, -): - """Unknown field in extra_params surfaces through the validator's wrong-scope/unaccepted rules.""" - mock_resolve_suite.return_value = _make_suite() - mock_resolve_tt.return_value = "Alpha_Trunc" - mock_tt_model.get.return_value = _make_test_type() # param_columns = {threshold_value} - mock_tg.get.return_value = _make_table_group() - - from testgen.mcp.tools.test_definitions import create_test - - # custom_query isn't accepted by Alpha_Trunc — validator should reject - with pytest.raises(MCPUserError) as exc_info: - create_test( - test_suite_id=str(uuid4()), - test_type="Alpha Truncation", - table_name="orders", - column_name="email", - threshold_value="10", - extra_params={"custom_query": "SELECT 1"}, - ) - assert "custom_query" in str(exc_info.value) + # Both common and type-specific fields applied via setattr + assert saved.threshold_value == "10" + assert saved.window_days == "7" + saved.validate.assert_called_once() + saved.save.assert_called_once() +@patch("testgen.mcp.tools.test_definitions.TestDefinition") @patch("testgen.mcp.tools.test_definitions.TableGroup") @patch("testgen.mcp.tools.test_definitions.TestType") @patch("testgen.mcp.tools.test_definitions.resolve_test_type") @patch("testgen.mcp.tools.test_definitions.resolve_test_suite") def test_create_test_severity_invalid( - mock_resolve_suite, mock_resolve_tt, mock_tt_model, mock_tg, db_session_mock + mock_resolve_suite, mock_resolve_tt, mock_tt_model, mock_tg, mock_td, db_session_mock, ): + """severity outside the StrEnum → validate() raises.""" + from testgen.common.models.test_definition import InvalidTestDefinitionFields + mock_resolve_suite.return_value = _make_suite() mock_resolve_tt.return_value = "Alpha_Trunc" mock_tt_model.get.return_value = _make_test_type() mock_tg.get.return_value = _make_table_group() + saved = MagicMock(id=uuid4()) + saved.editable_fields.return_value = { + "test_active", "severity", "lock_refresh", "flagged", "test_description", + "threshold_value", "column_name", + } + saved.validate.side_effect = InvalidTestDefinitionFields( + {"severity": "must be `Fail` or `Warning` (got `critical`)"} + ) + mock_td.return_value = saved + from testgen.mcp.tools.test_definitions import create_test with pytest.raises(MCPUserError) as exc_info: @@ -835,11 +820,10 @@ def test_create_test_severity_invalid( test_suite_id=str(uuid4()), test_type="Alpha Truncation", table_name="orders", - column_name="email", - threshold_value="64", - severity="critical", + fields={"column_name": "email", "threshold_value": "64", "severity": "critical"}, ) assert "severity" in str(exc_info.value) + saved.save.assert_not_called() # -- update_test -------------------------------------------------------------- @@ -932,19 +916,22 @@ def test_update_test_multi_field(mock_resolve_td, mock_tt_model, db_session_mock # -- validate_custom_test ----------------------------------------------------- -@patch("testgen.mcp.tools.test_definitions.fetch_from_target_db") +@patch("testgen.mcp.tools.test_definitions.validate_custom_query") +@patch("testgen.mcp.tools.test_definitions.TableGroup") @patch("testgen.mcp.tools.test_definitions.Connection") @patch("testgen.mcp.tools.test_definitions.resolve_test_suite") def test_validate_custom_test_would_pass_when_no_rows( - mock_resolve_suite, mock_conn, mock_fetch, db_session_mock + mock_resolve_suite, mock_conn, mock_tg, mock_validate, db_session_mock, ): + mock_resolve_suite.return_value = _make_suite() conn = MagicMock() conn.connection_name = "warehouse" conn.sql_flavor_code = "snowflake" conn.sql_flavor = "snowflake" mock_conn.get_by_table_group.return_value = conn - mock_fetch.return_value = [] + mock_tg.get.return_value = _make_table_group() + mock_validate.return_value = CustomQueryResult(row_count=0, preview_rows=[]) from testgen.mcp.tools.test_definitions import validate_custom_test @@ -952,19 +939,20 @@ def test_validate_custom_test_would_pass_when_no_rows( assert "ran successfully" in result.lower() assert "would pass" in result.lower() - assert "0 error rows" in result + assert "0 rows matching the failure criteria" in result @patch("testgen.mcp.permissions._compute_project_permissions") -@patch("testgen.mcp.tools.test_definitions.fetch_from_target_db") +@patch("testgen.mcp.tools.test_definitions.validate_custom_query") +@patch("testgen.mcp.tools.test_definitions.TableGroup") @patch("testgen.mcp.tools.test_definitions.Connection") @patch("testgen.mcp.tools.test_definitions.resolve_test_suite") def test_validate_custom_test_would_fail_shows_preview_with_view_pii( - mock_resolve_suite, mock_conn, mock_fetch, mock_compute, db_session_mock + mock_resolve_suite, mock_conn, mock_tg, mock_validate, mock_compute, db_session_mock, ): - # Grant view_pii on "demo" so values are visible in the preview. from testgen.mcp.permissions import ProjectPermissions + # Grant view_pii on "demo" so values are visible in the preview. perms = MagicMock(spec=ProjectPermissions) perms.allowed_codes = ["demo"] perms.codes_allowed_to.return_value = ["demo"] @@ -977,42 +965,45 @@ def test_validate_custom_test_would_fail_shows_preview_with_view_pii( conn.sql_flavor_code = "snowflake" conn.sql_flavor = "snowflake" mock_conn.get_by_table_group.return_value = conn + mock_tg.get.return_value = _make_table_group() row = MagicMock() row.keys.return_value = ["order_id", "amount"] row.__getitem__.side_effect = lambda k: {"order_id": "ORD-123", "amount": "-45.99"}[k] - mock_fetch.return_value = [row, row, row] + mock_validate.return_value = CustomQueryResult(row_count=3, preview_rows=[row]) from testgen.mcp.tools.test_definitions import validate_custom_test result = validate_custom_test(str(uuid4()), "SELECT * FROM orders WHERE amount < 0") assert "would fail" in result.lower() - assert "3 error row" in result + assert "3 row(s) matching the failure criteria" in result assert "order_id" in result assert "ORD-123" in result - # No redaction banner when view_pii is granted assert "[redacted]" not in result -@patch("testgen.mcp.tools.test_definitions.fetch_from_target_db") +@patch("testgen.mcp.tools.test_definitions.validate_custom_query") +@patch("testgen.mcp.tools.test_definitions.TableGroup") @patch("testgen.mcp.tools.test_definitions.Connection") @patch("testgen.mcp.tools.test_definitions.resolve_test_suite") def test_validate_custom_test_redacts_when_no_view_pii( - mock_resolve_suite, mock_conn, mock_fetch, db_session_mock + mock_resolve_suite, mock_conn, mock_tg, mock_validate, db_session_mock, ): - # Default fixture user has role_a with edit but not view_pii (view_pii not in test matrix → empty) + + # Default fixture user has role_a with edit but not view_pii. mock_resolve_suite.return_value = _make_suite() conn = MagicMock() conn.connection_name = "warehouse" conn.sql_flavor_code = "snowflake" conn.sql_flavor = "snowflake" mock_conn.get_by_table_group.return_value = conn + mock_tg.get.return_value = _make_table_group() row = MagicMock() row.keys.return_value = ["order_id", "customer_email"] row.__getitem__.side_effect = lambda k: {"order_id": "ORD-123", "customer_email": "jane@example.com"}[k] - mock_fetch.return_value = [row] + mock_validate.return_value = CustomQueryResult(row_count=1, preview_rows=[row]) from testgen.mcp.tools.test_definitions import validate_custom_test @@ -1021,17 +1012,20 @@ def test_validate_custom_test_redacts_when_no_view_pii( # Column names always visible assert "order_id" in result assert "customer_email" in result - # Values redacted because view_pii not granted in the default test matrix + # Values redacted; PII footer mentions permissions (no `view_pii` jargon) assert "[redacted]" in result assert "jane@example.com" not in result assert "ORD-123" not in result + assert "permissions to view PII" in result + assert "view_pii" not in result -@patch("testgen.mcp.tools.test_definitions.fetch_from_target_db") +@patch("testgen.mcp.tools.test_definitions.validate_custom_query") +@patch("testgen.mcp.tools.test_definitions.TableGroup") @patch("testgen.mcp.tools.test_definitions.Connection") @patch("testgen.mcp.tools.test_definitions.resolve_test_suite") def test_validate_custom_test_sql_error_surfaced( - mock_resolve_suite, mock_conn, mock_fetch, db_session_mock + mock_resolve_suite, mock_conn, mock_tg, mock_validate, db_session_mock, ): mock_resolve_suite.return_value = _make_suite() conn = MagicMock() @@ -1039,7 +1033,8 @@ def test_validate_custom_test_sql_error_surfaced( conn.sql_flavor_code = "postgresql" conn.sql_flavor = "postgresql" mock_conn.get_by_table_group.return_value = conn - mock_fetch.side_effect = Exception('syntax error at or near "FROMM"') + mock_tg.get.return_value = _make_table_group() + mock_validate.side_effect = Exception('syntax error at or near "FROMM"') from testgen.mcp.tools.test_definitions import validate_custom_test @@ -1069,7 +1064,7 @@ def test_validate_custom_test_missing_connection(mock_resolve_suite, mock_conn, def test_bulk_update_tests_disable_no_filter(mock_resolve_suite, mock_session, db_session_mock): mock_resolve_suite.return_value = _make_suite() result_mock = MagicMock() - result_mock.all.return_value = [(uuid4(),), (uuid4(),), (uuid4(),)] + result_mock.rowcount = 3 mock_session.return_value.execute.return_value = result_mock from testgen.mcp.tools.test_definitions import bulk_update_tests @@ -1089,7 +1084,7 @@ def test_bulk_update_tests_enable_with_table_filter( ): mock_resolve_suite.return_value = _make_suite() result_mock = MagicMock() - result_mock.all.return_value = [(uuid4(),)] + result_mock.rowcount = 1 mock_session.return_value.execute.return_value = result_mock from testgen.mcp.tools.test_definitions import bulk_update_tests @@ -1123,7 +1118,7 @@ def test_bulk_update_tests_invalid_action(mock_resolve_suite, mock_session, db_s def test_bulk_update_tests_no_match(mock_resolve_suite, mock_session, db_session_mock): mock_resolve_suite.return_value = _make_suite() result_mock = MagicMock() - result_mock.all.return_value = [] + result_mock.rowcount = 0 mock_session.return_value.execute.return_value = result_mock from testgen.mcp.tools.test_definitions import bulk_update_tests From 079331da9dc30ec350dec69be85e1654bb47dfd0 Mon Sep 17 00:00:00 2001 From: Ricardo Boni Date: Tue, 12 May 2026 12:40:12 -0400 Subject: [PATCH 08/58] refactor: consolidate row-limiting clauses into FlavorService - Add ``FlavorService.row_limit_clauses(n)`` returning ``(prefix, suffix)`` SQL fragments for the flavor's row-limiting style (``LIMIT``/``TOP``/ ``FETCH FIRST``). Replaces three duplicate inline switches. - ``data_catalog.py`` and ``refresh_data_chars_query.py`` now call the method instead of branching on ``row_limiting_clause``. - Normalise the access-check projection to literal ``1`` across flavors (was ``*`` for ``TOP``, ``1`` for ``LIMIT``/``FETCH``); add a parametrised test asserting the SQL per flavor. - Fix a bug in ``validate_custom_query``: ``fetch_from_target_db`` returns ``RowMapping`` (column-name access), not tuples. Alias the count as ``row_count`` and access by name. Unit-test mocks updated to the real return shape. - Drop the now-redundant ``from __future__ import annotations`` from ``custom_test_validation.py``. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../queries/refresh_data_chars_query.py | 8 +--- testgen/common/custom_test_validation.py | 24 ++--------- .../common/database/flavor/flavor_service.py | 8 ++++ testgen/ui/views/data_catalog.py | 7 ++- .../queries/test_refresh_data_chars_query.py | 20 +++++++++ .../common/test_custom_test_validation.py | 43 +++++-------------- 6 files changed, 48 insertions(+), 62 deletions(-) diff --git a/testgen/commands/queries/refresh_data_chars_query.py b/testgen/commands/queries/refresh_data_chars_query.py index e5d72fa5..b494c308 100644 --- a/testgen/commands/queries/refresh_data_chars_query.py +++ b/testgen/commands/queries/refresh_data_chars_query.py @@ -124,12 +124,8 @@ def verify_access(self, table_name: str) -> tuple[str, None]: schema = self.table_group.table_group_schema quote = self.flavor_service.quote_character table_ref = f"{quote}{schema}{quote}.{quote}{table_name}{quote}" - if (row_limiting := self.flavor_service.row_limiting_clause) == "top": - query = f"SELECT TOP 1 * FROM {table_ref}" - elif row_limiting == "fetch": - query = f"SELECT 1 FROM {table_ref} FETCH FIRST 1 ROWS ONLY" - else: - query = f"SELECT 1 FROM {table_ref} LIMIT 1" + prefix, suffix = self.flavor_service.row_limit_clauses(1) + query = f"SELECT {prefix} 1 FROM {table_ref} {suffix}".strip() return (query, None) def get_staging_data_chars(self, data_chars: list[ColumnChars], run_date: datetime) -> list[list[str | bool | int]]: diff --git a/testgen/common/custom_test_validation.py b/testgen/common/custom_test_validation.py index 82de4754..bf2fb85b 100644 --- a/testgen/common/custom_test_validation.py +++ b/testgen/common/custom_test_validation.py @@ -9,20 +9,14 @@ - DDL/DML rejection — non-SELECT statements fail to parse as a subquery. """ -from __future__ import annotations - from dataclasses import dataclass, field -from typing import TYPE_CHECKING from sqlalchemy.engine import RowMapping from testgen.common.database.database_service import get_flavor_service, replace_params +from testgen.common.models.connection import Connection from testgen.ui.services.database_service import fetch_from_target_db -if TYPE_CHECKING: - from testgen.common.database.flavor.flavor_service import FlavorService - from testgen.common.models.connection import Connection - @dataclass class CustomQueryResult: @@ -52,24 +46,14 @@ def validate_custom_query( sql_with_schema = replace_params(custom_sql, {"DATA_SCHEMA": schema}).rstrip().rstrip(";") flavor_service = get_flavor_service(connection.sql_flavor) - count_sql = f"SELECT COUNT(*) FROM ({sql_with_schema}) ERR_TABLE" + count_sql = f"SELECT COUNT(*) AS row_count FROM ({sql_with_schema}) ERR_TABLE" count_rows = fetch_from_target_db(connection, count_sql) - row_count = int(count_rows[0][0]) if count_rows else 0 + row_count = int(count_rows[0]["row_count"]) if count_rows else 0 preview_rows: list[RowMapping] = [] if preview_limit > 0 and row_count > 0: - prefix, suffix = _row_limit_clauses(flavor_service, preview_limit) + prefix, suffix = flavor_service.row_limit_clauses(preview_limit) preview_sql = f"SELECT {prefix} * FROM ({sql_with_schema}) ERR_TABLE {suffix}".strip() preview_rows = fetch_from_target_db(connection, preview_sql) return CustomQueryResult(row_count=row_count, preview_rows=preview_rows) - - -def _row_limit_clauses(flavor_service: FlavorService, n: int) -> tuple[str, str]: - """Return (prefix, suffix) for limiting a SELECT to N rows on the given flavor.""" - clause = flavor_service.row_limiting_clause - if clause == "top": - return f"TOP {n}", "" - if clause == "fetch": - return "", f"FETCH FIRST {n} ROWS ONLY" - return "", f"LIMIT {n}" diff --git a/testgen/common/database/flavor/flavor_service.py b/testgen/common/database/flavor/flavor_service.py index a56ac8ba..09406ea4 100644 --- a/testgen/common/database/flavor/flavor_service.py +++ b/testgen/common/database/flavor/flavor_service.py @@ -93,6 +93,14 @@ class FlavorService: varchar_type = "VARCHAR(1000)" ddf_table_ref = "table_name" row_limiting_clause: RowLimitingClause = "limit" + + def row_limit_clauses(self, n: int) -> tuple[str, str]: + """Return ``(prefix, suffix)`` SQL fragments for limiting a SELECT to ``n`` rows.""" + if self.row_limiting_clause == "top": + return f"TOP {n}", "" + if self.row_limiting_clause == "fetch": + return "", f"FETCH FIRST {n} ROWS ONLY" + return "", f"LIMIT {n}" default_uppercase = False test_query = "SELECT 1" url_scheme = "postgresql" diff --git a/testgen/ui/views/data_catalog.py b/testgen/ui/views/data_catalog.py index 80ebb16b..5327a081 100644 --- a/testgen/ui/views/data_catalog.py +++ b/testgen/ui/views/data_catalog.py @@ -899,15 +899,14 @@ def get_preview_data( return {"title": title, "status": "ERR", "message": "Connection not found."} flavor_service = get_flavor_service(connection.sql_flavor) - row_limiting = flavor_service.row_limiting_clause + prefix, suffix = flavor_service.row_limit_clauses(100) quote = flavor_service.quote_character query = f""" SELECT DISTINCT - {"TOP 100" if row_limiting == "top" else ""} + {prefix} {f"{quote}{column_name}{quote}" if column_name else "*"} FROM {quote}{schema_name}{quote}.{quote}{table_name}{quote} - {"LIMIT 100" if row_limiting == "limit" else ""} - {"FETCH FIRST 100 ROWS ONLY" if row_limiting == "fetch" else ""} + {suffix} """ try: diff --git a/tests/unit/commands/queries/test_refresh_data_chars_query.py b/tests/unit/commands/queries/test_refresh_data_chars_query.py index 9118d586..47c52179 100644 --- a/tests/unit/commands/queries/test_refresh_data_chars_query.py +++ b/tests/unit/commands/queries/test_refresh_data_chars_query.py @@ -7,6 +7,26 @@ pytestmark = pytest.mark.unit +@pytest.mark.parametrize( + "flavor,expected_sql", + [ + ("postgresql", 'SELECT 1 FROM "test_schema"."orders" LIMIT 1'), + ("mssql", 'SELECT TOP 1 1 FROM "test_schema"."orders"'), + ("oracle", 'SELECT 1 FROM "test_schema"."orders" FETCH FIRST 1 ROWS ONLY'), + ], +) +def test_verify_access_uses_literal_1_projection(flavor, expected_sql): + """Access check uses literal ``1`` (not ``*``) — projection doesn't matter for an + existence/permission probe, and ``1`` avoids materialising columns on wide tables.""" + connection = Connection(sql_flavor=flavor) + table_group = TableGroup(table_group_schema="test_schema") + sql_generator = RefreshDataCharsSQL(connection, table_group) + + query, _ = sql_generator.verify_access("orders") + + assert query == expected_sql + + def test_include_exclude_mask_basic(): connection = Connection(sql_flavor="postgresql") table_group = TableGroup( diff --git a/tests/unit/common/test_custom_test_validation.py b/tests/unit/common/test_custom_test_validation.py index 78302fec..0e007067 100644 --- a/tests/unit/common/test_custom_test_validation.py +++ b/tests/unit/common/test_custom_test_validation.py @@ -6,14 +6,14 @@ from testgen.common.custom_test_validation import ( CustomQueryResult, - _row_limit_clauses, validate_custom_query, ) +from testgen.common.database.flavor.flavor_service import FlavorService -def _flavor_service(row_limiting: str = "limit") -> MagicMock: - svc = MagicMock() - svc.row_limiting_clause = row_limiting +def _flavor_service(row_limiting: str = "limit") -> FlavorService: + svc = FlavorService() + svc.row_limiting_clause = row_limiting # type: ignore[assignment] return svc @@ -23,27 +23,6 @@ def _connection(flavor: str = "postgresql") -> MagicMock: return conn -# -- _row_limit_clauses ------------------------------------------------------- - - -def test_row_limit_clauses_limit_flavor(): - prefix, suffix = _row_limit_clauses(_flavor_service("limit"), 5) - assert prefix == "" - assert suffix == "LIMIT 5" - - -def test_row_limit_clauses_top_flavor(): - prefix, suffix = _row_limit_clauses(_flavor_service("top"), 5) - assert prefix == "TOP 5" - assert suffix == "" - - -def test_row_limit_clauses_fetch_flavor(): - prefix, suffix = _row_limit_clauses(_flavor_service("fetch"), 5) - assert prefix == "" - assert suffix == "FETCH FIRST 5 ROWS ONLY" - - # -- validate_custom_query ---------------------------------------------------- @@ -51,7 +30,7 @@ def test_row_limit_clauses_fetch_flavor(): @patch("testgen.common.custom_test_validation.get_flavor_service") def test_validate_custom_query_count_only(mock_get_flavor, mock_fetch): mock_get_flavor.return_value = _flavor_service("limit") - mock_fetch.return_value = [(0,)] + mock_fetch.return_value = [{"row_count": 0}] result = validate_custom_query( _connection(), "demo", "SELECT * FROM orders WHERE total < 0", @@ -75,7 +54,7 @@ def test_validate_custom_query_with_preview(mock_get_flavor, mock_fetch): preview_row = MagicMock() preview_row.keys.return_value = ["order_id", "amount"] mock_fetch.side_effect = [ - [(3,)], # count query result + [{"row_count": 3}], # count query result [preview_row], # preview query result ] @@ -96,7 +75,7 @@ def test_validate_custom_query_with_preview(mock_get_flavor, mock_fetch): @patch("testgen.common.custom_test_validation.get_flavor_service") def test_validate_custom_query_preview_skipped_when_no_rows(mock_get_flavor, mock_fetch): mock_get_flavor.return_value = _flavor_service("limit") - mock_fetch.return_value = [(0,)] + mock_fetch.return_value = [{"row_count": 0}] result = validate_custom_query( _connection(), "demo", "SELECT 1 WHERE 1=0", preview_limit=5, @@ -112,7 +91,7 @@ def test_validate_custom_query_preview_skipped_when_no_rows(mock_get_flavor, moc @patch("testgen.common.custom_test_validation.get_flavor_service") def test_validate_custom_query_substitutes_data_schema(mock_get_flavor, mock_fetch): mock_get_flavor.return_value = _flavor_service("limit") - mock_fetch.return_value = [(0,)] + mock_fetch.return_value = [{"row_count": 0}] validate_custom_query( _connection(), @@ -131,7 +110,7 @@ def test_validate_custom_query_substitutes_data_schema(mock_get_flavor, mock_fet def test_validate_custom_query_strips_trailing_semicolon(mock_get_flavor, mock_fetch): """Trailing semicolons break the subquery wrap — must be stripped.""" mock_get_flavor.return_value = _flavor_service("limit") - mock_fetch.return_value = [(0,)] + mock_fetch.return_value = [{"row_count": 0}] validate_custom_query( _connection(), "demo", "SELECT 1; ", @@ -151,7 +130,7 @@ def test_validate_custom_query_uses_flavor_specific_limit(mock_get_flavor, mock_ mock_get_flavor.return_value = _flavor_service("fetch") preview_row = MagicMock() mock_fetch.side_effect = [ - [(5,)], + [{"row_count": 5}], [preview_row], ] @@ -170,7 +149,7 @@ def test_validate_custom_query_top_flavor_uses_prefix(mock_get_flavor, mock_fetc mock_get_flavor.return_value = _flavor_service("top") preview_row = MagicMock() mock_fetch.side_effect = [ - [(5,)], + [{"row_count": 5}], [preview_row], ] From 11bba6263c0b6d13c2d9a8455a3cf8a5e86228f9 Mon Sep 17 00:00:00 2001 From: Ricardo Boni Date: Mon, 11 May 2026 15:46:24 -0400 Subject: [PATCH 09/58] =?UTF-8?q?feat(mcp):=20profiling=20L3=20=E2=80=94?= =?UTF-8?q?=20cross-column=20search,=20frequent=20values,=20patterns=20(TG?= =?UTF-8?q?-1067)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds three new MCP tools and extends list_column_profiles with predicate filters: - list_column_profiles: 17 new optional filter args (null/distinct/filled ratios, scores, pii/cde booleans, suggested_data_type enum, scoping enums) + ordering enum - get_column_frequent_values: top-N values for one column with PII redaction - get_column_patterns: top character patterns for one string column - search_columns: cross-scope column-name search with per-project match summary Folds the ticket's original find_columns_by_profile into list_column_profiles since input scope and output row shape were identical. CDE filter coalesces column- and table-level flags; MANUAL pii_flag is included in the High risk-level filter (matches the dq_score_weight_defaults seed). --- testgen/common/models/data_column.py | 160 +++++- testgen/common/models/profile_result.py | 41 +- testgen/common/profile_top_values.py | 51 ++ testgen/mcp/permissions.py | 8 + testgen/mcp/server.py | 6 + testgen/mcp/tools/common.py | 83 +++ testgen/mcp/tools/profiling.py | 432 +++++++++++++++- testgen/mcp/tools/source_data.py | 3 +- testgen/mcp/tools/test_definitions.py | 3 +- testgen/ui/views/data_catalog.py | 6 +- testgen/ui/views/profiling_results.py | 14 +- tests/unit/common/test_profile_top_values.py | 85 ++++ tests/unit/mcp/test_model_data_column.py | 61 +++ tests/unit/mcp/test_model_profile_result.py | 90 ++++ tests/unit/mcp/test_permissions.py | 34 ++ tests/unit/mcp/test_tools_common.py | 135 +++++ tests/unit/mcp/test_tools_profiling.py | 505 ++++++++++++++++++- 17 files changed, 1688 insertions(+), 29 deletions(-) create mode 100644 testgen/common/profile_top_values.py create mode 100644 tests/unit/common/test_profile_top_values.py create mode 100644 tests/unit/mcp/test_model_profile_result.py diff --git a/testgen/common/models/data_column.py b/testgen/common/models/data_column.py index 81d9c125..cf350519 100644 --- a/testgen/common/models/data_column.py +++ b/testgen/common/models/data_column.py @@ -1,5 +1,6 @@ from dataclasses import dataclass from datetime import datetime +from enum import StrEnum from uuid import UUID, uuid4 from sqlalchemy import ( @@ -12,6 +13,7 @@ and_, asc, case, + desc, func, select, ) @@ -24,6 +26,65 @@ from testgen.common.models.profiling_run import ProfilingRun +class GeneralType(StrEnum): + """User-facing word values for the column ``general_type``.""" + + ALPHA = "Alpha" + NUMERIC = "Numeric" + DATETIME = "Datetime" + BOOLEAN = "Boolean" + TIME = "Time" + OTHER = "Other" + + +# Translates the user-facing words to the single-letter codes stored on +# ``data_column_chars.general_type`` for WHERE-clause matching. +GENERAL_TYPE_TO_CODE: dict[GeneralType, str] = { + GeneralType.ALPHA: "A", + GeneralType.NUMERIC: "N", + GeneralType.DATETIME: "D", + GeneralType.BOOLEAN: "B", + GeneralType.TIME: "T", + GeneralType.OTHER: "X", +} + + +class SuggestedDataType(StrEnum): + """Values accepted for the ``suggested_data_type`` argument.""" + + ANY = "Any" + INTEGER = "Integer" + NUMERIC = "Numeric" + VARCHAR = "Varchar" + DATE = "Date" + TIMESTAMP = "Timestamp" + BOOLEAN = "Boolean" + + +# Maps the user-facing word to the SQL-type prefix matched against +# ``datatype_suggestion`` (``Any`` is a sentinel — no prefix, just non-null check). +SUGGESTED_DATA_TYPE_TO_PREFIX: dict[SuggestedDataType, str | None] = { + SuggestedDataType.ANY: None, + SuggestedDataType.INTEGER: "INTEGER", + SuggestedDataType.NUMERIC: "NUMERIC", + SuggestedDataType.VARCHAR: "VARCHAR", + SuggestedDataType.DATE: "DATE", + SuggestedDataType.TIMESTAMP: "TIMESTAMP", + SuggestedDataType.BOOLEAN: "BOOLEAN", +} + + +class ColumnOrderBy(StrEnum): + """Values accepted for the ``order_by`` argument on column profile listings.""" + + NULL_RATIO = "Null Ratio" + DISTINCT_RATIO = "Distinct Ratio" + FILLED_RATIO = "Filled Ratio" + SCORE_PROFILING = "Profiling Score" + SCORE_TESTING = "Testing Score" + HYGIENE_COUNT = "Hygiene Count" + + @dataclass class ColumnProfileSummary(EntityMinimal): column_name: str @@ -124,6 +185,16 @@ class ColumnProfileDetail(EntityMinimal): profile_run_log_message: str | None +@dataclass +class ColumnSearchHit(EntityMinimal): + project_code: str + table_groups_id: UUID + table_groups_name: str + schema_name: str | None + table_name: str + column_name: str + + class DataColumnChars(Entity): __tablename__ = "data_column_chars" @@ -162,6 +233,7 @@ def list_for_table_group( *clauses, table_groups_id: UUID, profiling_run_id: UUID | None = None, + order_by: ColumnOrderBy | None = None, page: int, limit: int, ) -> tuple[list[ColumnProfileSummary], int]: @@ -246,9 +318,29 @@ def list_for_table_group( cls.drop_date.is_(None), *clauses, ) - .order_by(asc(cls.table_name), asc(cls.ordinal_position), asc(cls.column_name)) ) + null_ratio_expr = ProfileResult.null_value_ct * 1.0 / func.nullif(ProfileResult.record_ct, 0) + distinct_ratio_expr = ProfileResult.distinct_value_ct * 1.0 / func.nullif(ProfileResult.record_ct, 0) + filled_ratio_expr = ProfileResult.filled_value_ct * 1.0 / func.nullif(ProfileResult.record_ct, 0) + order_exprs: tuple + if order_by is ColumnOrderBy.NULL_RATIO: + order_exprs = (desc(null_ratio_expr).nulls_last(),) + elif order_by is ColumnOrderBy.DISTINCT_RATIO: + order_exprs = (asc(distinct_ratio_expr).nulls_last(),) + elif order_by is ColumnOrderBy.FILLED_RATIO: + order_exprs = (desc(filled_ratio_expr).nulls_last(),) + elif order_by is ColumnOrderBy.SCORE_PROFILING: + order_exprs = (asc(cls.dq_score_profiling).nulls_last(),) + elif order_by is ColumnOrderBy.SCORE_TESTING: + order_exprs = (asc(cls.dq_score_testing).nulls_last(),) + elif order_by is ColumnOrderBy.HYGIENE_COUNT: + order_exprs = (desc(func.coalesce(hygiene_subq.c.hygiene_issue_count, 0)),) + else: + order_exprs = (asc(cls.table_name), asc(cls.ordinal_position), asc(cls.column_name)) + + query = query.order_by(*order_exprs) + return cls._paginate(query, page=page, limit=limit, data_class=ColumnProfileSummary) @classmethod @@ -400,3 +492,69 @@ def get_column_detail( row = get_current_session().execute(query).mappings().first() return ColumnProfileDetail(**row) if row else None + + @classmethod + def search_by_name( + cls, + *clauses, + pattern: str, + page: int, + limit: int, + ) -> tuple[list[ColumnSearchHit], int]: + """Cross-table-group column-name search. Scoping clauses are passed in by the caller. + + ``pattern`` is matched with ``ILIKE``. Callers are expected to pre-wrap bare + tokens with ``%`` if substring search is desired; literal ``%`` / ``_`` from + the caller are honored as wildcards. + """ + # Local import: avoid circular dependency with TableGroup. + from testgen.common.models.table_group import TableGroup + + query = ( + select( + TableGroup.project_code, + TableGroup.id.label("table_groups_id"), + TableGroup.table_groups_name, + cls.schema_name, + cls.table_name, + cls.column_name, + ) + .join(TableGroup, TableGroup.id == cls.table_groups_id) + .where( + cls.column_name.ilike(pattern, escape="\\"), + cls.drop_date.is_(None), + *clauses, + ) + .order_by( + asc(TableGroup.project_code), + asc(TableGroup.table_groups_name), + asc(cls.table_name), + asc(cls.column_name), + ) + ) + + return cls._paginate(query, page=page, limit=limit, data_class=ColumnSearchHit) + + @classmethod + def summarize_matches_by_project( + cls, + *clauses, + pattern: str, + ) -> list[tuple[str, int]]: + """Per-project match counts for a column-name search — same WHERE shape as :meth:`search_by_name`.""" + # Local import: avoid circular dependency with TableGroup. + from testgen.common.models.table_group import TableGroup + + query = ( + select(TableGroup.project_code, func.count().label("match_count")) + .select_from(cls) + .join(TableGroup, TableGroup.id == cls.table_groups_id) + .where( + cls.column_name.ilike(pattern, escape="\\"), + cls.drop_date.is_(None), + *clauses, + ) + .group_by(TableGroup.project_code) + .order_by(TableGroup.project_code) + ) + return [(row.project_code, row.match_count) for row in get_current_session().execute(query).all()] diff --git a/testgen/common/models/profile_result.py b/testgen/common/models/profile_result.py index 31f37337..0eef47d6 100644 --- a/testgen/common/models/profile_result.py +++ b/testgen/common/models/profile_result.py @@ -1,7 +1,7 @@ from datetime import datetime from uuid import UUID, uuid4 -from sqlalchemy import BigInteger, Column, Float, ForeignKey, Integer, Numeric, String, asc +from sqlalchemy import BigInteger, Column, Float, ForeignKey, Integer, Numeric, String, asc, desc from sqlalchemy.dialects import postgresql from testgen.common.models.entity import Entity @@ -85,3 +85,42 @@ class ProfileResult(Entity): query_error: str | None = Column(String) _default_order_by = (asc(position), asc(column_name)) + + @classmethod + def get_for_column( + cls, + table_groups_id: UUID, + table_name: str, + column_name: str, + profiling_run_id: UUID | None = None, + ) -> "ProfileResult | None": + """Fetch the profile-results row for one column. + + Resolves to the explicit ``profiling_run_id`` when given, otherwise to the + column's latest profile run (via ``data_column_chars.last_complete_profile_run_id``). + Returns ``None`` when no row exists. + """ + # Local import: data_column imports ProfileResult at module top. + from testgen.common.models.data_column import DataColumnChars + + clauses = [ + cls.table_groups_id == table_groups_id, + cls.table_name == table_name, + cls.column_name == column_name, + ] + if profiling_run_id is not None: + clauses.append(cls.profile_run_id == profiling_run_id) + else: + latest = list( + DataColumnChars.select_where( + DataColumnChars.table_groups_id == table_groups_id, + DataColumnChars.table_name == table_name, + DataColumnChars.column_name == column_name, + ) + ) + if not latest or latest[0].last_complete_profile_run_id is None: + return None + clauses.append(cls.profile_run_id == latest[0].last_complete_profile_run_id) + + rows = list(cls.select_where(*clauses, order_by=(desc(cls.profile_run_id),))) + return rows[0] if rows else None diff --git a/testgen/common/profile_top_values.py b/testgen/common/profile_top_values.py new file mode 100644 index 00000000..ac7d9298 --- /dev/null +++ b/testgen/common/profile_top_values.py @@ -0,0 +1,51 @@ +"""Parsers for the ``top_freq_values`` and ``top_patterns`` fields written by profiling. + +Both fields are stored as delimited strings on ``profile_results``. This module +splits them back into structured rows; format quirks (separators, leading markers, +values containing the separator) are handled here so they only need fixing in one +place. +""" + + +def parse_top_freq_values(raw: str | None) -> list[tuple[str, int]]: + """Parse ``top_freq_values`` text into ``[(value, count), ...]``. + + Stored format: ``| value | count\\n| value | count ...`` — each row begins with + ``| ``, value and count are separated by `` | ``, rows are joined by ``\\n``. + Uses :py:meth:`str.rpartition` so values containing `` | `` parse correctly + (the count is always the rightmost segment). + """ + if not raw: + return [] + body = raw[2:] if raw.startswith("| ") else raw + rows: list[tuple[str, int]] = [] + for part in body.split("\n| "): + if " | " not in part: + continue + value, _, count = part.rpartition(" | ") + try: + rows.append((value.strip(), int(count.strip()))) + except ValueError: + continue + return rows + + +def parse_top_patterns(raw: str | None) -> list[tuple[str, int]]: + """Parse ``top_patterns`` text into ``[(pattern, count), ...]``. + + Stored format: alternating ``count | pattern | count | pattern ...`` (SQL + templates emit segments separated by `` | ``; the odd-indexed segment is the + pattern, the even-indexed is the count). + """ + if not raw: + return [] + parts = [p.strip() for p in raw.split(" | ")] + rows: list[tuple[str, int]] = [] + for index in range(0, len(parts) - 1, 2): + try: + count = int(parts[index]) + except ValueError: + continue + pattern = parts[index + 1] + rows.append((pattern, count)) + return rows diff --git a/testgen/mcp/permissions.py b/testgen/mcp/permissions.py index dce78000..0850c753 100644 --- a/testgen/mcp/permissions.py +++ b/testgen/mcp/permissions.py @@ -39,6 +39,14 @@ def has_access(self, project_code: str) -> bool: """For filtering lists — no exception, just a bool.""" return project_code in self.allowed_codes + def has_permission(self, permission: str, project_code: str) -> bool: + """Whether the user has ``permission`` on ``project_code`` (single-check predicate). + + For per-row checks in tight loops, prefer caching the result of + :meth:`codes_allowed_to` once and using a set lookup. + """ + return project_code in self.codes_allowed_to(permission) + def verify_access(self, project_code: str, not_found: "str | MCPPermissionDenied") -> None: """Raise MCPPermissionDenied if user can't access this project. diff --git a/testgen/mcp/server.py b/testgen/mcp/server.py index f506bef6..77fa0bc0 100644 --- a/testgen/mcp/server.py +++ b/testgen/mcp/server.py @@ -152,12 +152,15 @@ def build_mcp_server( update_hygiene_issue, ) from testgen.mcp.tools.profiling import ( + get_column_frequent_values, + get_column_patterns, get_column_profile_detail, get_profiling_run, get_table, list_column_profiles, list_profiling_runs, list_profiling_summaries, + search_columns, ) from testgen.mcp.tools.reference import ( column_profile_fields_resource, @@ -237,6 +240,9 @@ def safe_prompt(fn): safe_tool(list_profiling_runs) safe_tool(get_profiling_run) safe_tool(get_column_profile_detail) + safe_tool(get_column_frequent_values) + safe_tool(get_column_patterns) + safe_tool(search_columns) safe_tool(run_tests) safe_tool(run_profiling) safe_tool(cancel_test_run) diff --git a/testgen/mcp/tools/common.py b/testgen/mcp/tools/common.py index 78794976..13235af1 100644 --- a/testgen/mcp/tools/common.py +++ b/testgen/mcp/tools/common.py @@ -7,6 +7,12 @@ from testgen.common.date_service import parse_since from testgen.common.enums import ImpactDimension, QualityDimension from testgen.common.models import get_current_session +from testgen.common.models.data_column import ( + GENERAL_TYPE_TO_CODE, + ColumnOrderBy, + GeneralType, + SuggestedDataType, +) from testgen.common.models.hygiene_issue import Disposition, HygieneIssueType, IssueLikelihood, PiiRisk from testgen.common.models.job_execution import JobStatus from testgen.common.models.profiling_run import ProfilingRun @@ -176,6 +182,83 @@ def parse_issue_likelihood_list(values: list[str]) -> list[IssueLikelihood]: return parsed +# Maps the user-facing display label to the stored ``pii_flag`` middle segment +# (``A//``). Mirrors ``_PII_TYPE_MAP`` in ``profiling.py``. +_PII_CATEGORY_TO_CODE: dict[str, str] = { + "ID": "ID", + "Name": "NAME", + "Demographic": "DEMO", + "Contact": "CONTACT", +} + + +def build_ilike_pattern(raw: str) -> str: + """Prepare a free-text input for an ``ILIKE`` clause. + + Escapes literal underscores (which column names commonly contain) so they + match as themselves rather than as the SQL single-character wildcard. When + the input contains an explicit ``%``, honor it as the caller's wildcard; + otherwise wrap the input with ``%...%`` for substring match. + + Pair with ``column.ilike(pattern, escape="\\\\")`` at the call site. + """ + escaped = raw.replace("_", r"\_") + return escaped if "%" in escaped else f"%{escaped}%" + + +def parse_pii_category(value: str) -> str: + """Validate a pii_category value and return the stored ``pii_flag`` middle segment.""" + code = _PII_CATEGORY_TO_CODE.get(value) + if code is None: + valid = ", ".join(_PII_CATEGORY_TO_CODE) + raise MCPUserError(f"Invalid pii_category `{value}`. Valid values: {valid}") + return code + + +def parse_general_type(value: str) -> str: + """Validate a user-facing ``general_type`` word and return the stored single-letter code. + + Accepts ``Alpha`` / ``Numeric`` / ``Datetime`` / ``Boolean`` / ``Time`` / ``Other``; + returns ``A`` / ``N`` / ``D`` / ``B`` / ``T`` / ``X`` respectively (the values stored + on ``data_column_chars.general_type``). + """ + try: + member = GeneralType(value) + except ValueError as err: + valid = ", ".join(t.value for t in GeneralType) + raise MCPUserError(f"Invalid general_type `{value}`. Valid values: {valid}") from err + return GENERAL_TYPE_TO_CODE[member] + + +def parse_suggested_data_type(value: str) -> SuggestedDataType: + try: + return SuggestedDataType(value) + except ValueError as err: + valid = ", ".join(t.value for t in SuggestedDataType) + raise MCPUserError(f"Invalid suggested_data_type `{value}`. Valid values: {valid}") from err + + +def parse_column_order_by(value: str) -> ColumnOrderBy: + try: + return ColumnOrderBy(value) + except ValueError as err: + valid = ", ".join(o.value for o in ColumnOrderBy) + raise MCPUserError(f"Invalid order_by `{value}`. Valid values: {valid}") from err + + +# ``pii_flag`` encodes risk as a single-character prefix: ``A`` (High), ``B`` (Moderate), ``C`` (Low). +_PII_RISK_LEVEL_TO_CODE: dict[str, str] = {"High": "A", "Moderate": "B", "Low": "C"} + + +def parse_pii_risk_level(value: str) -> str: + """Validate a column-profile pii_risk_level filter and return the stored prefix code.""" + code = _PII_RISK_LEVEL_TO_CODE.get(value) + if code is None: + valid = ", ".join(_PII_RISK_LEVEL_TO_CODE) + raise MCPUserError(f"Invalid pii_risk_level `{value}`. Valid values: {valid}") + return code + + def parse_pii_risk_list(values: list[str]) -> list[PiiRisk]: parsed: list[PiiRisk] = [] invalid: list[str] = [] diff --git a/testgen/mcp/tools/profiling.py b/testgen/mcp/tools/profiling.py index aea98447..4e3034f7 100644 --- a/testgen/mcp/tools/profiling.py +++ b/testgen/mcp/tools/profiling.py @@ -1,23 +1,39 @@ import dataclasses from uuid import UUID +from sqlalchemy import func, or_ + from testgen.common.models import with_database_session -from testgen.common.models.data_column import ColumnProfileDetail, ColumnProfileSummary, DataColumnChars +from testgen.common.models.data_column import ( + SUGGESTED_DATA_TYPE_TO_PREFIX, + ColumnOrderBy, + ColumnProfileDetail, + ColumnProfileSummary, + DataColumnChars, +) from testgen.common.models.data_table import DataTable from testgen.common.models.job_execution import JobExecution +from testgen.common.models.profile_result import ProfileResult from testgen.common.models.profiling_run import ProfilingRun, ProfilingRunSummary from testgen.common.models.scheduler import RUN_PROFILE_JOB_KEY from testgen.common.models.table_group import TableGroup, TableGroupSummary -from testgen.common.pii_masking import mask_profiling_pii +from testgen.common.pii_masking import PII_REDACTED, mask_profiling_pii +from testgen.common.profile_top_values import parse_top_freq_values, parse_top_patterns from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError from testgen.mcp.permissions import get_project_permissions, mcp_permission from testgen.mcp.tools.common import ( DocGroup, + build_ilike_pattern, format_page_footer, format_page_info, format_run_duration, next_scheduled_run, + parse_column_order_by, + parse_general_type, + parse_pii_category, + parse_pii_risk_level, parse_run_status_filter, + parse_suggested_data_type, parse_uuid, resolve_profiling_run, resolve_table_group, @@ -81,21 +97,78 @@ def list_column_profiles( table_name: str | None = None, columns: list[str] | None = None, job_execution_id: str | None = None, + null_ratio_above: float | None = None, + null_ratio_below: float | None = None, + distinct_ratio_above: float | None = None, + distinct_ratio_below: float | None = None, + filled_ratio_above: float | None = None, + filled_ratio_below: float | None = None, + score_profiling_above: float | None = None, + score_profiling_below: float | None = None, + score_testing_above: float | None = None, + score_testing_below: float | None = None, + pii: bool | None = None, + cde: bool | None = None, + suggested_data_type: str | None = None, + general_type: str | None = None, + functional_data_type: str | None = None, + pii_category: str | None = None, + pii_risk_level: str | None = None, + order_by: str | None = None, limit: int = 100, page: int = 1, ) -> str: - """List per-column profile headers (~14 fields each) — the Layer 1 scan of profiling results across columns in a table group. + """List per-column profile headers across a table group, with optional profile-predicate filters. Args: table_group_id: UUID of the table group, e.g. from `get_data_inventory`. table_name: Optional — scope to one table (case-sensitive). columns: Optional — specific column names to include (case-sensitive). job_execution_id: UUID of a profiling run, e.g. from `get_table` or - `list_profiling_summaries`. When omitted, each column uses its own - latest run. - limit: Page size (default 100). + `list_profiling_summaries`. When omitted, each column uses its own latest run. + null_ratio_above: Match columns whose null fraction exceeds this value + (e.g. `0.2` for above 20% null). + null_ratio_below: Match columns whose null fraction is below this value. + distinct_ratio_above: Match columns whose distinct-value fraction exceeds this + value (e.g. `0.95` for near-unique columns). + distinct_ratio_below: Match columns whose distinct-value fraction is below this + value (e.g. `0.001` for low cardinality). + filled_ratio_above: Match columns whose dummy/placeholder-value fraction exceeds + this value. + filled_ratio_below: Match columns whose dummy/placeholder-value fraction is below + this value. + score_profiling_above: Match columns whose Profiling Score is above this value. + score_profiling_below: Match columns whose Profiling Score is below this value. + score_testing_above: Match columns whose Testing Score is above this value. + score_testing_below: Match columns whose Testing Score is below this value. + pii: When `true`, match columns flagged as PII; when `false`, exclude PII columns. + cde: When `true`, match columns flagged as a Critical Data Element (directly + or inherited from the table); when `false`, exclude CDE columns. + suggested_data_type: Match columns where profiling suggests a more suitable data + type. Pass `Any` for any mismatch, or a concrete type (`Integer`, `Numeric`, + `Varchar`, `Date`, `Timestamp`, `Boolean`) to filter mismatches whose + suggestion starts with that type. Columns where the suggestion matches the + column's stored type are always excluded. + general_type: Broad type classification — + `Alpha`, `Numeric`, `Datetime`, `Boolean`, `Time`, or `Other`. + functional_data_type: Substring match (case-insensitive) on Semantic Data Type. + Use a cluster prefix to catch related variants — `Period` matches + `Period`, `Period Month`, `Period Year`, etc.; `ID` matches `ID`, `ID-FK`, + `ID-Unique`, etc.; `Transactional Date` matches all of its variants. Bare + tokens auto-wrap with `%`; an explicit `%` in the input is honored as a + wildcard. The set of values is open-ended — discover available values by + listing columns without this filter, then narrow. + pii_category: PII category — `ID`, `Name`, `Demographic`, or `Contact`. + pii_risk_level: PII risk level — `High`, `Moderate`, or `Low`. + order_by: Sort key — `Null Ratio`, `Distinct Ratio`, `Filled Ratio`, + `Profiling Score`, `Testing Score`, or `Hygiene Count`. Defaults to + table/column position. + limit: Page size (default 100, max 500). page: Page number starting at 1 (default 1). """ + validate_page(page) + validate_limit(limit, 500) + tg = resolve_table_group(table_group_id) profiling_run_id: UUID | None = None @@ -111,10 +184,98 @@ def list_column_profiles( if columns: clauses.append(DataColumnChars.column_name.in_(columns)) + if null_ratio_above is not None: + clauses.append(ProfileResult.null_value_ct * 1.0 / func.nullif(ProfileResult.record_ct, 0) > null_ratio_above) + if null_ratio_below is not None: + clauses.append(ProfileResult.null_value_ct * 1.0 / func.nullif(ProfileResult.record_ct, 0) < null_ratio_below) + if distinct_ratio_above is not None: + clauses.append( + ProfileResult.distinct_value_ct * 1.0 / func.nullif(ProfileResult.record_ct, 0) > distinct_ratio_above + ) + if distinct_ratio_below is not None: + clauses.append( + ProfileResult.distinct_value_ct * 1.0 / func.nullif(ProfileResult.record_ct, 0) < distinct_ratio_below + ) + if filled_ratio_above is not None: + clauses.append( + ProfileResult.filled_value_ct * 1.0 / func.nullif(ProfileResult.record_ct, 0) > filled_ratio_above + ) + if filled_ratio_below is not None: + clauses.append( + ProfileResult.filled_value_ct * 1.0 / func.nullif(ProfileResult.record_ct, 0) < filled_ratio_below + ) + + if score_profiling_above is not None: + clauses.append(DataColumnChars.dq_score_profiling > score_profiling_above) + if score_profiling_below is not None: + clauses.append(DataColumnChars.dq_score_profiling < score_profiling_below) + if score_testing_above is not None: + clauses.append(DataColumnChars.dq_score_testing > score_testing_above) + if score_testing_below is not None: + clauses.append(DataColumnChars.dq_score_testing < score_testing_below) + + if pii is True: + clauses.append(DataColumnChars.pii_flag.isnot(None)) + elif pii is False: + clauses.append(DataColumnChars.pii_flag.is_(None)) + + if cde is True: + # A column is a CDE when either it or its parent table is flagged. + clauses.append( + or_( + DataColumnChars.critical_data_element.is_(True), + DataTable.critical_data_element.is_(True), + ) + ) + elif cde is False: + clauses.append( + DataColumnChars.critical_data_element.isnot(True), + ) + clauses.append( + DataTable.critical_data_element.isnot(True), + ) + + if suggested_data_type is not None: + prefix = SUGGESTED_DATA_TYPE_TO_PREFIX[parse_suggested_data_type(suggested_data_type)] + if prefix is None: + clauses.append(ProfileResult.datatype_suggestion.isnot(None)) + else: + clauses.append(ProfileResult.datatype_suggestion.ilike(f"{prefix}%")) + + if general_type is not None: + clauses.append(DataColumnChars.general_type == parse_general_type(general_type)) + if functional_data_type is not None: + if not functional_data_type.strip(): + raise MCPUserError("`functional_data_type` cannot be empty.") + clauses.append( + DataColumnChars.functional_data_type.ilike( + build_ilike_pattern(functional_data_type), escape="\\" + ) + ) + if pii_category is not None: + category = parse_pii_category(pii_category) + # ``pii_flag`` stores ``//``; match on the middle segment. + clauses.append(DataColumnChars.pii_flag.like(f"%/{category}/%")) + if pii_risk_level is not None: + risk_code = parse_pii_risk_level(pii_risk_level) + # ``MANUAL`` is user-set PII, weighted equivalent to ``A`` (High) by ``dq_score_weight_defaults``. + if risk_code == "A": + clauses.append( + or_( + DataColumnChars.pii_flag.like("A/%"), + DataColumnChars.pii_flag == "MANUAL", + ) + ) + else: + clauses.append(DataColumnChars.pii_flag.like(f"{risk_code}/%")) + + order_value: ColumnOrderBy | None = parse_column_order_by(order_by) if order_by else None + data, total = DataColumnChars.list_for_table_group( *clauses, table_groups_id=tg.id, profiling_run_id=profiling_run_id, + order_by=order_value, page=page, limit=limit, ) @@ -488,6 +649,45 @@ def _format_std_pattern(value: str | None) -> str | None: return _STD_PATTERN_LABELS.get(value, value.replace("_", " ").title()) +# --------------------------------------------------------------------------- +# Shared helpers for single-column tools (frequent values, patterns) +# --------------------------------------------------------------------------- + + +def _load_profile_for_column( + tg: TableGroup, + table_name: str, + column_name: str, + job_execution_id: str | None, +) -> tuple[ProfileResult, ProfilingRun]: + """Resolve and load the profile-results row for one column, paired with its ``ProfilingRun``.""" + profiling_run: ProfilingRun | None = None + if job_execution_id: + profiling_run = resolve_profiling_run(job_execution_id) + if profiling_run.table_groups_id != tg.id: + raise MCPResourceNotAccessible("Profiling run", job_execution_id) + profile = ProfileResult.get_for_column( + table_groups_id=tg.id, + table_name=table_name, + column_name=column_name, + profiling_run_id=profiling_run.id if profiling_run else None, + ) + if profile is None: + raise MCPResourceNotAccessible("Column profile", f"{table_name}.{column_name}") + if profiling_run is None: + profiling_run = ProfilingRun.get(profile.profile_run_id) + if profiling_run is None: + raise MCPResourceNotAccessible("Profiling run", str(profile.profile_run_id)) + return profile, profiling_run + + +def _is_pii_redacted_for_caller(tg: TableGroup, profile: ProfileResult) -> bool: + """Decide whether to redact PII values for this caller + column.""" + if not profile.pii_flag: + return False + return not get_project_permissions().has_permission("view_pii", tg.project_code) + + @with_database_session @mcp_permission("catalog") def get_column_profile_detail( @@ -536,9 +736,8 @@ def get_column_profile_detail( if detail.profile_run_status in ("Running", "Error", "Cancelled"): _raise_run_not_ready(detail) - perms = get_project_permissions() payload = dataclasses.asdict(detail) - if tg.project_code not in perms.codes_allowed_to("view_pii") and detail.pii_flag: + if detail.pii_flag and not get_project_permissions().has_permission("view_pii", tg.project_code): mask_profiling_pii(payload, {detail.column_name}) return _render_column_profile_detail(payload) @@ -710,3 +909,220 @@ def _render_boolean_block(doc: MdDoc, p: dict) -> None: def _render_unknown_block(doc: MdDoc, p: dict) -> None: _render_counts(doc, p) + + +# --------------------------------------------------------------------------- +# Single-column tools — frequent values and patterns +# --------------------------------------------------------------------------- + + +@with_database_session +@mcp_permission("catalog") +def get_column_frequent_values( + table_group_id: str, + table_name: str, + column_name: str, + job_execution_id: str | None = None, +) -> str: + """Get the top frequent values for one column from its profile run, with row counts and percentages. + + Profiling captures the top 10 values; when the column has more distinct values, a + trailing `Other Values (N)` row aggregates the remainder. + + Args: + table_group_id: UUID of the table group, e.g. from `get_data_inventory`. + table_name: Table name exactly as stored in TestGen (case-sensitive). + column_name: Column name exactly as stored in TestGen (case-sensitive). + job_execution_id: UUID of a profiling run. When omitted, uses the column's + latest profile run. + """ + tg = resolve_table_group(table_group_id) + profile, profiling_run = _load_profile_for_column(tg, table_name, column_name, job_execution_id) + + doc = MdDoc() + doc.heading(1, f"Frequent values: {table_name}.{column_name}") + doc.field("Table group", tg.id, code=True) + doc.field("Profiling Run", profiling_run.job_execution_id, code=True) + doc.field("Records", profile.record_ct) + doc.field("Distinct values", profile.distinct_value_ct) + if profile.pii_flag: + doc.field("PII", _format_pii(profile.pii_flag)) + + rows = parse_top_freq_values(profile.top_freq_values) + if not rows: + doc.text( + f"_Frequency data not available — high cardinality " + f"(distinct count: {profile.distinct_value_ct})._" + ) + return doc.render() + + redact = _is_pii_redacted_for_caller(tg, profile) + record_ct = profile.record_ct or 0 + display_rows: list[list[object]] = [] + for value, count in rows: + pct = (count / record_ct * 100) if record_ct else None + display_value = PII_REDACTED if redact else value + display_rows.append([display_value, count, f"{pct:.2f}%" if pct is not None else None]) + + doc.heading(2, "Top values") + doc.table(["Value", "Count", "% of records"], display_rows) + return doc.render() + + +@with_database_session +@mcp_permission("catalog") +def get_column_patterns( + table_group_id: str, + table_name: str, + column_name: str, + job_execution_id: str | None = None, +) -> str: + """Get the top character patterns for one string column from its profile run. + + Patterns use shorthand: `A` = uppercase letter, `a` = lowercase letter, `N` = digit; + every other character (whitespace, punctuation, symbols) appears literally. Examples: + `Aaaaaaaa` (capitalized word), `NNNN-NN-NN` (ISO-like date), `aaa@aaa.aaa` (email-shaped). + Profiling captures the top 5 patterns. + + Args: + table_group_id: UUID of the table group, e.g. from `get_data_inventory`. + table_name: Table name exactly as stored in TestGen (case-sensitive). + column_name: Column name exactly as stored in TestGen (case-sensitive). + job_execution_id: UUID of a profiling run. When omitted, uses the column's + latest profile run. + """ + tg = resolve_table_group(table_group_id) + profile, profiling_run = _load_profile_for_column(tg, table_name, column_name, job_execution_id) + + doc = MdDoc() + doc.heading(1, f"Character patterns: {table_name}.{column_name}") + doc.field("Table group", tg.id, code=True) + doc.field("Profiling Run", profiling_run.job_execution_id, code=True) + doc.field("Records", profile.record_ct) + doc.field("Distinct values", profile.distinct_value_ct) + + if profile.general_type and profile.general_type != "A": + doc.text("_Pattern data not available — column is not a string type._") + return doc.render() + + rows = parse_top_patterns(profile.top_patterns) + if not rows: + doc.text( + f"_Pattern data not available — high cardinality " + f"(distinct count: {profile.distinct_value_ct})._" + ) + return doc.render() + + record_ct = profile.record_ct or 0 + display_rows: list[list[object]] = [] + for pattern, count in rows: + pct = (count / record_ct * 100) if record_ct else None + display_rows.append([pattern, count, f"{pct:.2f}%" if pct is not None else None]) + + doc.heading(2, "Top patterns") + doc.table(["Pattern", "Count", "% of records"], display_rows, code=[0]) + return doc.render() + + +# --------------------------------------------------------------------------- +# Cross-scope column-name search +# --------------------------------------------------------------------------- + + +@with_database_session +@mcp_permission("catalog") +def search_columns( + pattern: str, + project_code: str | None = None, + table_group_id: str | None = None, + limit: int = 100, + page: int = 1, +) -> str: + """Search columns by name across one or many projects (bare tokens auto-wrap as `%token%`; explicit `%` honored as a wildcard). + + Args: + pattern: Column-name search pattern. Case-insensitive. + project_code: Optional — scope to one project. Mutually exclusive with + `table_group_id`. + table_group_id: Optional — scope to one table group. Mutually exclusive + with `project_code`. + limit: Page size (default 100, max 500). + page: Page number starting at 1 (default 1). + """ + validate_page(page) + validate_limit(limit, 500) + + if not pattern or not pattern.strip(): + raise MCPUserError("`pattern` is required and cannot be empty.") + effective_pattern = build_ilike_pattern(pattern) + + if project_code is not None and table_group_id is not None: + raise MCPUserError("Pass either `project_code` or `table_group_id`, not both.") + + perms = get_project_permissions() + clauses: list = [] + + if table_group_id is not None: + tg = resolve_table_group(table_group_id) + clauses.append(DataColumnChars.table_groups_id == tg.id) + scope_label = f"table group `{table_group_id}`" + elif project_code is not None: + perms.verify_access( + project_code, + not_found=MCPResourceNotAccessible("Project", project_code), + ) + clauses.append(TableGroup.project_code == project_code) + scope_label = f"project `{project_code}`" + else: + # The @mcp_permission decorator guarantees ``allowed_codes`` is non-empty by + # the time the body runs (it raises MCPPermissionDenied otherwise). + clauses.append(TableGroup.project_code.in_(list(perms.allowed_codes))) + scope_label = "all accessible projects" + + data, total = DataColumnChars.search_by_name( + *clauses, + pattern=effective_pattern, + page=page, + limit=limit, + ) + + if not data: + if page > 1: + return f"No columns matching `{pattern}` on page {page} (total: {total})." + return f"No columns matching `{pattern}` in {scope_label}." + + doc = MdDoc() + doc.heading(1, f"Columns matching `{pattern}` in {scope_label}") + + page_info = format_page_info(total, page, limit) + if page_info: + doc.text(page_info) + + # Per-project match summary when no scope was provided. + if project_code is None and table_group_id is None: + summary_rows = DataColumnChars.summarize_matches_by_project( + *clauses, + pattern=effective_pattern, + ) + if summary_rows: + doc.heading(2, "Matches by project") + doc.table( + ["Project", "Matches"], + [[code_, count] for code_, count in summary_rows], + code=[0], + ) + + doc.heading(2, "Columns") + doc.table( + ["Project", "Table group", "Schema", "Table", "Column"], + [ + [hit.project_code, hit.table_groups_name, hit.schema_name, hit.table_name, hit.column_name] + for hit in data + ], + code=[0, 1], + ) + + footer = format_page_footer(total, page, limit) + if footer: + doc.text(footer) + return doc.render() diff --git a/testgen/mcp/tools/source_data.py b/testgen/mcp/tools/source_data.py index 1e75b78c..1b3fb0bb 100644 --- a/testgen/mcp/tools/source_data.py +++ b/testgen/mcp/tools/source_data.py @@ -101,8 +101,7 @@ def get_source_data( validate_limit(limit, 500) context = _resolve_context(test_definition_id, reference_date) - perms = get_project_permissions() - mask_pii = context.get("project_code") not in perms.codes_allowed_to("view_pii") + mask_pii = not get_project_permissions().has_permission("view_pii", context.get("project_code")) result: SourceDataResult = fetch_test_result_source_data(context, limit, mask_pii) diff --git a/testgen/mcp/tools/test_definitions.py b/testgen/mcp/tools/test_definitions.py index d95d27c4..8a87e677 100644 --- a/testgen/mcp/tools/test_definitions.py +++ b/testgen/mcp/tools/test_definitions.py @@ -533,8 +533,7 @@ def validate_custom_test(test_suite_id: str, custom_sql: str) -> str: if table_group is None: raise MCPUserError("Test suite is not associated with a table group.") - perms = get_project_permissions() - can_view_pii = suite.project_code in perms.codes_allowed_to("view_pii") + can_view_pii = get_project_permissions().has_permission("view_pii", suite.project_code) doc = MdDoc() doc.heading(1, "Custom test dry-run") diff --git a/testgen/ui/views/data_catalog.py b/testgen/ui/views/data_catalog.py index 5327a081..6636e41e 100644 --- a/testgen/ui/views/data_catalog.py +++ b/testgen/ui/views/data_catalog.py @@ -23,6 +23,7 @@ mask_profiling_pii, mask_source_data_pii, ) +from testgen.common.profile_top_values import parse_top_freq_values, parse_top_patterns from testgen.ui.components import widgets as testgen from testgen.ui.components.widgets.download_dialog import ( FILE_DATA_TYPE, @@ -499,13 +500,12 @@ def get_excel_report_data(update_progress: PROGRESS_UPDATE_TYPE, table_group: Ta axis=1, ) data["top_freq_values"] = data["top_freq_values"].apply( - lambda val: "\n".join([f"{part.split(' | ')[1]} | {part.split(' | ')[0]}" for part in val[2:].split("\n| ")]) + lambda val: "\n".join(f"{count} | {value}" for value, count in parse_top_freq_values(val)) if not pd.isna(val) and val != PII_REDACTED else val ) - nl = "\n" # For Python 3.11 compatibility data["top_patterns"] = data["top_patterns"].apply( - lambda val: "".join([f"{part}{nl if index % 2 else ' | '}" for index, part in enumerate(val.split(" | "))]) + lambda val: "\n".join(f"{count} | {pattern}" for pattern, count in parse_top_patterns(val)) if not pd.isna(val) and val != PII_REDACTED else val ) diff --git a/testgen/ui/views/profiling_results.py b/testgen/ui/views/profiling_results.py index 83608a12..d777d297 100644 --- a/testgen/ui/views/profiling_results.py +++ b/testgen/ui/views/profiling_results.py @@ -16,6 +16,7 @@ mask_profiling_pii, mask_source_data_pii, ) +from testgen.common.profile_top_values import parse_top_freq_values, parse_top_patterns from testgen.ui.components import widgets as testgen from testgen.ui.components.widgets.download_dialog import ( FILE_DATA_TYPE, @@ -325,21 +326,12 @@ def get_excel_report_data( def _format_top_freq_values(val): if not val or val == PII_REDACTED: return val - lines = [] - for part in val[2:].split("\n| "): - left, right = part.split(" | ") - lines.append(f"{right} | {left}") - return "\n".join(lines) + return "\n".join(f"{count} | {value}" for value, count in parse_top_freq_values(val)) def _format_top_patterns(val): if not val or val == PII_REDACTED: return val - parts = val.split(" | ") - formatted = [] - for index, part in enumerate(parts): - separator = "\n" if index % 2 else " | " - formatted.append(f"{part}{separator}") - return "".join(formatted) + return "\n".join(f"{count} | {pattern}" for pattern, count in parse_top_patterns(val)) data["top_freq_values"] = data["top_freq_values"].apply(_format_top_freq_values) data["top_patterns"] = data["top_patterns"].apply(_format_top_patterns) diff --git a/tests/unit/common/test_profile_top_values.py b/tests/unit/common/test_profile_top_values.py new file mode 100644 index 00000000..e3c74661 --- /dev/null +++ b/tests/unit/common/test_profile_top_values.py @@ -0,0 +1,85 @@ +from testgen.common.profile_top_values import parse_top_freq_values, parse_top_patterns + +# --- parse_top_freq_values --- + + +def test_parse_top_freq_values_three_rows(): + raw = "| Mexico | 182\n| USA | 176\n| Canada | 144" + assert parse_top_freq_values(raw) == [("Mexico", 182), ("USA", 176), ("Canada", 144)] + + +def test_parse_top_freq_values_with_other_values_aggregate_row(): + # The profiling pipeline emits a synthetic "Other Values (N)" row when distinct count > 10. + raw = "| a | 5\n| b | 4\n| Other Values (8) | 20" + assert parse_top_freq_values(raw) == [("a", 5), ("b", 4), ("Other Values (8)", 20)] + + +def test_parse_top_freq_values_value_containing_separator(): + # rpartition: count is always rightmost, so a value with " | " in it parses correctly. + raw = "| user | password | 42" + assert parse_top_freq_values(raw) == [("user | password", 42)] + + +def test_parse_top_freq_values_none_input(): + assert parse_top_freq_values(None) == [] + + +def test_parse_top_freq_values_empty_input(): + assert parse_top_freq_values("") == [] + + +def test_parse_top_freq_values_skips_unparseable_count(): + raw = "| good | 10\n| bad | not_a_number\n| also_good | 5" + assert parse_top_freq_values(raw) == [("good", 10), ("also_good", 5)] + + +def test_parse_top_freq_values_skips_rows_without_separator(): + raw = "alone\n| good | 5" + assert parse_top_freq_values(raw) == [("good", 5)] + + +def test_parse_top_freq_values_trims_whitespace_around_value(): + raw = "| spacey | 7" + assert parse_top_freq_values(raw) == [("spacey", 7)] + + +def test_parse_top_freq_values_tolerates_missing_leading_marker(): + raw = "alone | 9" + assert parse_top_freq_values(raw) == [("alone", 9)] + + +# --- parse_top_patterns --- + + +def test_parse_top_patterns_three_pairs(): + raw = "326 | Aaaaaa | 176 | AAA | 50 | aaa" + assert parse_top_patterns(raw) == [("Aaaaaa", 326), ("AAA", 176), ("aaa", 50)] + + +def test_parse_top_patterns_email_shape(): + raw = "200 | aaa@aaa.aaa" + assert parse_top_patterns(raw) == [("aaa@aaa.aaa", 200)] + + +def test_parse_top_patterns_none_input(): + assert parse_top_patterns(None) == [] + + +def test_parse_top_patterns_empty_input(): + assert parse_top_patterns("") == [] + + +def test_parse_top_patterns_skips_pair_with_unparseable_count(): + raw = "10 | good | xx | bad | 5 | also_good" + assert parse_top_patterns(raw) == [("good", 10), ("also_good", 5)] + + +def test_parse_top_patterns_dangling_odd_segment_ignored(): + # An odd number of segments — the trailing count without a pattern is dropped. + raw = "10 | Aaa | 99" + assert parse_top_patterns(raw) == [("Aaa", 10)] + + +def test_parse_top_patterns_trims_pattern_whitespace(): + raw = "5 | NNNN-NN-NN " + assert parse_top_patterns(raw) == [("NNNN-NN-NN", 5)] diff --git a/tests/unit/mcp/test_model_data_column.py b/tests/unit/mcp/test_model_data_column.py index 102c811f..aa8b0181 100644 --- a/tests/unit/mcp/test_model_data_column.py +++ b/tests/unit/mcp/test_model_data_column.py @@ -238,3 +238,64 @@ def test_get_column_detail_no_pin_uses_last_complete_profile_run_id(session_mock query = call_args[0][0] sql_str = str(query.compile(compile_kwargs={"literal_binds": True})) assert "last_complete_profile_run_id" in sql_str + + +# ---------------------------------------------------------------------- +# DataColumnChars.search_by_name +# ---------------------------------------------------------------------- + + +@patch.object(DataColumnChars, "_paginate") +def test_search_by_name_joins_table_group_and_orders_for_stable_pagination(paginate_mock): + paginate_mock.return_value = ([], 0) + + DataColumnChars.search_by_name(pattern="%email%", page=1, limit=10) + + query = paginate_mock.call_args[0][0] + sql_str = str(query.compile(compile_kwargs={"literal_binds": True})) + # Join to table_groups + ILIKE on column_name + the expected ordering for stable paging. + assert "table_groups" in sql_str.lower() + assert "ilike" in sql_str.lower() or "like" in sql_str.lower() + assert "ORDER BY" in sql_str + assert "project_code" in sql_str + assert "%email%" in sql_str + + +@patch.object(DataColumnChars, "_paginate") +def test_search_by_name_excludes_dropped_columns(paginate_mock): + paginate_mock.return_value = ([], 0) + + DataColumnChars.search_by_name(pattern="%x%", page=1, limit=10) + + query = paginate_mock.call_args[0][0] + sql_str = str(query.compile(compile_kwargs={"literal_binds": True})) + assert "drop_date IS NULL" in sql_str + + +# ---------------------------------------------------------------------- +# DataColumnChars.summarize_matches_by_project +# ---------------------------------------------------------------------- + + +@patch("testgen.common.models.data_column.get_current_session") +def test_summarize_matches_by_project_returns_project_count_tuples(session_mock): + row_a = type("Row", (), {"project_code": "DEFAULT", "match_count": 6})() + row_b = type("Row", (), {"project_code": "DEMO_2", "match_count": 1})() + session_mock.return_value.execute.return_value.all.return_value = [row_a, row_b] + + result = DataColumnChars.summarize_matches_by_project(pattern="%email%") + + assert result == [("DEFAULT", 6), ("DEMO_2", 1)] + + +@patch("testgen.common.models.data_column.get_current_session") +def test_summarize_matches_by_project_groups_and_orders_by_project(session_mock): + session_mock.return_value.execute.return_value.all.return_value = [] + + DataColumnChars.summarize_matches_by_project(pattern="%x%") + + query = session_mock.return_value.execute.call_args[0][0] + sql_str = str(query.compile(compile_kwargs={"literal_binds": True})) + assert "GROUP BY" in sql_str + assert "ORDER BY" in sql_str + assert "project_code" in sql_str.lower() diff --git a/tests/unit/mcp/test_model_profile_result.py b/tests/unit/mcp/test_model_profile_result.py new file mode 100644 index 00000000..1a4ad475 --- /dev/null +++ b/tests/unit/mcp/test_model_profile_result.py @@ -0,0 +1,90 @@ +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +from testgen.common.models.profile_result import ProfileResult + + +@patch("testgen.common.models.profile_result.ProfileResult.select_where") +@patch("testgen.common.models.data_column.DataColumnChars.select_where") +def test_get_for_column_returns_row_when_run_pinned(dcc_select, pr_select): + pinned_run_id = uuid4() + profile = MagicMock(spec=ProfileResult) + pr_select.return_value = [profile] + + result = ProfileResult.get_for_column( + table_groups_id=uuid4(), + table_name="customers", + column_name="email", + profiling_run_id=pinned_run_id, + ) + + assert result is profile + # When a profile run is explicitly pinned, we should not fall back to data_column_chars. + dcc_select.assert_not_called() + + +@patch("testgen.common.models.profile_result.ProfileResult.select_where") +@patch("testgen.common.models.data_column.DataColumnChars.select_where") +def test_get_for_column_resolves_latest_run_when_unpinned(dcc_select, pr_select): + latest_run_id = uuid4() + column = MagicMock() + column.last_complete_profile_run_id = latest_run_id + dcc_select.return_value = [column] + profile = MagicMock(spec=ProfileResult) + pr_select.return_value = [profile] + + result = ProfileResult.get_for_column( + table_groups_id=uuid4(), + table_name="customers", + column_name="email", + ) + + assert result is profile + dcc_select.assert_called_once() + + +@patch("testgen.common.models.profile_result.ProfileResult.select_where") +@patch("testgen.common.models.data_column.DataColumnChars.select_where") +def test_get_for_column_returns_none_when_column_unknown(dcc_select, pr_select): + dcc_select.return_value = [] + + result = ProfileResult.get_for_column( + table_groups_id=uuid4(), + table_name="customers", + column_name="ghost", + ) + + assert result is None + pr_select.assert_not_called() + + +@patch("testgen.common.models.profile_result.ProfileResult.select_where") +@patch("testgen.common.models.data_column.DataColumnChars.select_where") +def test_get_for_column_returns_none_when_column_never_profiled(dcc_select, pr_select): + column = MagicMock() + column.last_complete_profile_run_id = None + dcc_select.return_value = [column] + + result = ProfileResult.get_for_column( + table_groups_id=uuid4(), + table_name="customers", + column_name="email", + ) + + assert result is None + pr_select.assert_not_called() + + +@patch("testgen.common.models.profile_result.ProfileResult.select_where") +@patch("testgen.common.models.data_column.DataColumnChars.select_where") +def test_get_for_column_returns_none_when_pinned_run_has_no_row(dcc_select, pr_select): + pr_select.return_value = [] + + result = ProfileResult.get_for_column( + table_groups_id=uuid4(), + table_name="customers", + column_name="email", + profiling_run_id=uuid4(), + ) + + assert result is None diff --git a/tests/unit/mcp/test_permissions.py b/tests/unit/mcp/test_permissions.py index 4b058295..0f97cbe4 100644 --- a/tests/unit/mcp/test_permissions.py +++ b/tests/unit/mcp/test_permissions.py @@ -206,6 +206,40 @@ def test_has_access(): assert perms.has_access("proj_b") is False +# --- ProjectPermissions.has_permission --- + + +def test_has_permission_true_when_role_grants_it(): + perms = ProjectPermissions( + memberships={"proj_a": "role_a", "proj_b": "role_c"}, + permission="catalog", + username="test_user", + ) + # role_a is in the "view" allowlist; role_c is not. + assert perms.has_permission("view", "proj_a") is True + assert perms.has_permission("view", "proj_b") is False + + +def test_has_permission_false_when_project_not_member(): + perms = ProjectPermissions( + memberships={"proj_a": "role_a"}, + permission="catalog", + username="test_user", + ) + assert perms.has_permission("view", "proj_other") is False + + +def test_has_permission_decoupled_from_decorator_permission(): + # The decorator was "catalog", but we can query any permission. + perms = ProjectPermissions( + memberships={"proj_a": "role_a"}, + permission="catalog", + username="test_user", + ) + assert perms.has_permission("edit", "proj_a") is True + assert perms.has_permission("catalog", "proj_a") is True + + # --- get_project_permissions --- diff --git a/tests/unit/mcp/test_tools_common.py b/tests/unit/mcp/test_tools_common.py index 4a9eb432..72b82a46 100644 --- a/tests/unit/mcp/test_tools_common.py +++ b/tests/unit/mcp/test_tools_common.py @@ -330,3 +330,138 @@ def test_resolve_profiling_run_inaccessible_project(mock_pr_cls, mock_get_perms, def test_resolve_profiling_run_invalid_uuid(): with pytest.raises(MCPUserError, match="Invalid job_execution_id"): resolve_profiling_run("not-a-uuid") + + +# --- parse_pii_category --- + + +def test_parse_pii_category_translates_display_label_to_stored_code(): + from testgen.mcp.tools.common import parse_pii_category + assert parse_pii_category("ID") == "ID" + assert parse_pii_category("Name") == "NAME" + assert parse_pii_category("Demographic") == "DEMO" + assert parse_pii_category("Contact") == "CONTACT" + + +def test_parse_pii_category_rejects_stored_code_form(): + from testgen.mcp.tools.common import parse_pii_category + with pytest.raises(MCPUserError, match="Invalid pii_category `NAME`"): + parse_pii_category("NAME") + + +def test_parse_pii_category_lists_valid_values_in_error(): + from testgen.mcp.tools.common import parse_pii_category + with pytest.raises(MCPUserError, match="Valid values:") as exc_info: + parse_pii_category("Address") + for label in ("ID", "Name", "Demographic", "Contact"): + assert label in str(exc_info.value) + + +# --- parse_pii_risk_level --- + + +def test_parse_pii_risk_level_translates_label_to_stored_prefix(): + from testgen.mcp.tools.common import parse_pii_risk_level + assert parse_pii_risk_level("High") == "A" + assert parse_pii_risk_level("Moderate") == "B" + assert parse_pii_risk_level("Low") == "C" + + +def test_parse_pii_risk_level_rejects_unknown(): + from testgen.mcp.tools.common import parse_pii_risk_level + with pytest.raises(MCPUserError, match="Invalid pii_risk_level `Critical`"): + parse_pii_risk_level("Critical") + + +# --- parse_general_type --- + + +def test_parse_general_type_translates_word_to_letter_code(): + from testgen.mcp.tools.common import parse_general_type + assert parse_general_type("Alpha") == "A" + assert parse_general_type("Numeric") == "N" + assert parse_general_type("Datetime") == "D" + assert parse_general_type("Boolean") == "B" + assert parse_general_type("Time") == "T" + assert parse_general_type("Other") == "X" + + +def test_parse_general_type_rejects_letter_code_input(): + from testgen.mcp.tools.common import parse_general_type + with pytest.raises(MCPUserError, match="Invalid general_type `A`"): + parse_general_type("A") + + +def test_parse_general_type_is_case_sensitive(): + from testgen.mcp.tools.common import parse_general_type + with pytest.raises(MCPUserError): + parse_general_type("alpha") + + +# --- parse_suggested_data_type --- + + +def test_parse_suggested_data_type_accepts_title_case(): + from testgen.common.models.data_column import SuggestedDataType + from testgen.mcp.tools.common import parse_suggested_data_type + assert parse_suggested_data_type("Any") is SuggestedDataType.ANY + assert parse_suggested_data_type("Integer") is SuggestedDataType.INTEGER + assert parse_suggested_data_type("Varchar") is SuggestedDataType.VARCHAR + + +def test_parse_suggested_data_type_rejects_uppercase(): + from testgen.mcp.tools.common import parse_suggested_data_type + with pytest.raises(MCPUserError, match="Invalid suggested_data_type `INTEGER`"): + parse_suggested_data_type("INTEGER") + + +def test_parse_suggested_data_type_lists_valid_values_in_error(): + from testgen.mcp.tools.common import parse_suggested_data_type + with pytest.raises(MCPUserError) as exc_info: + parse_suggested_data_type("Bogus") + for label in ("Any", "Integer", "Numeric", "Varchar", "Date", "Timestamp", "Boolean"): + assert label in str(exc_info.value) + + +# --- parse_column_order_by --- + + +def test_parse_column_order_by_accepts_display_form(): + from testgen.common.models.data_column import ColumnOrderBy + from testgen.mcp.tools.common import parse_column_order_by + assert parse_column_order_by("Null Ratio") is ColumnOrderBy.NULL_RATIO + assert parse_column_order_by("Profiling Score") is ColumnOrderBy.SCORE_PROFILING + assert parse_column_order_by("Hygiene Count") is ColumnOrderBy.HYGIENE_COUNT + + +def test_parse_column_order_by_rejects_snake_case(): + from testgen.mcp.tools.common import parse_column_order_by + with pytest.raises(MCPUserError, match="Invalid order_by `null_ratio`"): + parse_column_order_by("null_ratio") + + +# --- build_ilike_pattern --- + + +def test_build_ilike_pattern_wraps_bare_token(): + from testgen.mcp.tools.common import build_ilike_pattern + assert build_ilike_pattern("email") == "%email%" + + +def test_build_ilike_pattern_escapes_literal_underscore(): + from testgen.mcp.tools.common import build_ilike_pattern + # Column names commonly contain underscores; treat them as literal, not as SQL wildcards. + assert build_ilike_pattern("user_id") == r"%user\_id%" + + +def test_build_ilike_pattern_honors_explicit_percent(): + from testgen.mcp.tools.common import build_ilike_pattern + # Caller-supplied % means "I'm doing my own wildcards" — don't double-wrap. + assert build_ilike_pattern("%email") == "%email" + assert build_ilike_pattern("user%") == "user%" + + +def test_build_ilike_pattern_escapes_underscores_even_with_explicit_percent(): + from testgen.mcp.tools.common import build_ilike_pattern + # The `_` escape is unconditional — explicit `%` doesn't suppress it. + assert build_ilike_pattern("user_%") == r"user\_%" diff --git a/tests/unit/mcp/test_tools_profiling.py b/tests/unit/mcp/test_tools_profiling.py index 5a4ec01f..e5d38052 100644 --- a/tests/unit/mcp/test_tools_profiling.py +++ b/tests/unit/mcp/test_tools_profiling.py @@ -4,7 +4,7 @@ import pytest -from testgen.common.models.data_column import ColumnProfileDetail, ColumnProfileSummary +from testgen.common.models.data_column import ColumnProfileDetail, ColumnProfileSummary, DataColumnChars from testgen.common.pii_masking import PII_REDACTED from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError from testgen.mcp.permissions import ProjectPermissions @@ -1311,3 +1311,506 @@ def test_get_column_profile_detail_query_error_section(mock_tg_cls, mock_dcc_cls assert "Profiling Error" in result assert "ORA-01017" in result + + +# ---------------------------------------------------------------------- +# list_column_profiles — predicate filters +# ---------------------------------------------------------------------- + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_null_ratio_above_adds_clause(mock_tg_cls, mock_dcc_cls, db_session_mock): + mock_tg_cls.get.return_value = _mock_table_group() + mock_dcc_cls.list_for_table_group.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_column_profiles + list_column_profiles(str(uuid4()), null_ratio_above=0.2) + + clauses = mock_dcc_cls.list_for_table_group.call_args[0] + assert any("null_value_ct" in str(c) for c in clauses) + + +@patch.object(DataColumnChars, "list_for_table_group") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_pii_true_adds_is_not_null_clause(mock_tg_cls, mock_method, db_session_mock): + mock_tg_cls.get.return_value = _mock_table_group() + mock_method.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_column_profiles + list_column_profiles(str(uuid4()), pii=True) + + sql = _compile_clauses(mock_method) + assert "pii_flag IS NOT NULL" in sql + + +@patch.object(DataColumnChars, "list_for_table_group") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_cde_true_coalesces_column_and_table_flag( + mock_tg_cls, mock_method, db_session_mock, +): + mock_tg_cls.get.return_value = _mock_table_group() + mock_method.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_column_profiles + list_column_profiles(str(uuid4()), cde=True) + + sql = _compile_clauses(mock_method) + assert "data_column_chars.critical_data_element IS true" in sql + assert "data_table_chars.critical_data_element IS true" in sql + assert "OR" in sql + + +@patch.object(DataColumnChars, "list_for_table_group") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_suggested_data_type_any_uses_is_not_null( + mock_tg_cls, mock_method, db_session_mock, +): + mock_tg_cls.get.return_value = _mock_table_group() + mock_method.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_column_profiles + list_column_profiles(str(uuid4()), suggested_data_type="Any") + + sql = _compile_clauses(mock_method) + assert "datatype_suggestion IS NOT NULL" in sql + + +@patch.object(DataColumnChars, "list_for_table_group") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_suggested_data_type_concrete_uses_prefix_ilike( + mock_tg_cls, mock_method, db_session_mock, +): + mock_tg_cls.get.return_value = _mock_table_group() + mock_method.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_column_profiles + list_column_profiles(str(uuid4()), suggested_data_type="Integer") + + sql = _compile_clauses(mock_method) + assert "INTEGER%" in sql + + +@patch.object(DataColumnChars, "list_for_table_group") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_general_type_translates_word_to_letter( + mock_tg_cls, mock_method, db_session_mock, +): + mock_tg_cls.get.return_value = _mock_table_group() + mock_method.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_column_profiles + list_column_profiles(str(uuid4()), general_type="Numeric") + + sql = _compile_clauses(mock_method) + assert "general_type = 'N'" in sql + + +@patch.object(DataColumnChars, "list_for_table_group") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_pii_category_translated_to_stored_code( + mock_tg_cls, mock_method, db_session_mock, +): + mock_tg_cls.get.return_value = _mock_table_group() + mock_method.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_column_profiles + list_column_profiles(str(uuid4()), pii_category="Contact") + + sql = _compile_clauses(mock_method) + assert "%/CONTACT/%" in sql + + +@patch.object(DataColumnChars, "list_for_table_group") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_pii_risk_level_high_includes_manual( + mock_tg_cls, mock_method, db_session_mock, +): + mock_tg_cls.get.return_value = _mock_table_group() + mock_method.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_column_profiles + list_column_profiles(str(uuid4()), pii_risk_level="High") + + sql = _compile_clauses(mock_method) + assert "'A/%'" in sql and "'MANUAL'" in sql + + +@patch.object(DataColumnChars, "list_for_table_group") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_pii_risk_level_moderate_does_not_include_manual( + mock_tg_cls, mock_method, db_session_mock, +): + mock_tg_cls.get.return_value = _mock_table_group() + mock_method.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_column_profiles + list_column_profiles(str(uuid4()), pii_risk_level="Moderate") + + sql = _compile_clauses(mock_method) + assert "'B/%'" in sql + assert "MANUAL" not in sql + + +@patch.object(DataColumnChars, "list_for_table_group") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_functional_data_type_uses_ilike( + mock_tg_cls, mock_method, db_session_mock, +): + mock_tg_cls.get.return_value = _mock_table_group() + mock_method.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_column_profiles + list_column_profiles(str(uuid4()), functional_data_type="Person Given") + + sql = _compile_clauses(mock_method) + # Default dialect renders ILIKE as ``LOWER(col) LIKE LOWER(pat) ESCAPE`` — same semantic. + assert "LIKE" in sql.upper() + assert "%Person Given%" in sql + + +@patch.object(DataColumnChars, "list_for_table_group") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_functional_data_type_underscore_escaped( + mock_tg_cls, mock_method, db_session_mock, +): + """Underscores in the input must be escaped (column names commonly contain them).""" + mock_tg_cls.get.return_value = _mock_table_group() + mock_method.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_column_profiles + list_column_profiles(str(uuid4()), functional_data_type="ID_FK") + + sql = _compile_clauses(mock_method) + # The escape clause appears, and the underscore is escaped in the pattern. + assert "ID\\_FK" in sql or "ID\\\\_FK" in sql + + +@patch.object(DataColumnChars, "list_for_table_group") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_functional_data_type_empty_rejected(mock_tg_cls, mock_method, db_session_mock): + mock_tg_cls.get.return_value = _mock_table_group() + + from testgen.mcp.tools.profiling import list_column_profiles + with pytest.raises(MCPUserError, match="`functional_data_type` cannot be empty"): + list_column_profiles(str(uuid4()), functional_data_type=" ") + + +@patch.object(DataColumnChars, "list_for_table_group") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_order_by_passes_enum_to_model(mock_tg_cls, mock_method, db_session_mock): + from testgen.common.models.data_column import ColumnOrderBy + + mock_tg_cls.get.return_value = _mock_table_group() + mock_method.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_column_profiles + list_column_profiles(str(uuid4()), order_by="Null Ratio") + + assert mock_method.call_args.kwargs["order_by"] is ColumnOrderBy.NULL_RATIO + + +def _compile_clauses(mock_method): + """Compile the *clauses arg of a captured ``list_for_table_group`` call into a single SQL string.""" + clauses = mock_method.call_args[0] + return " ".join(str(c.compile(compile_kwargs={"literal_binds": True})) for c in clauses) + + +# ---------------------------------------------------------------------- +# get_column_frequent_values +# ---------------------------------------------------------------------- + + +def _mock_profile_result(**overrides): + pr = MagicMock() + pr.profile_run_id = uuid4() + pr.record_ct = 500 + pr.distinct_value_ct = 3 + pr.pii_flag = None + pr.general_type = "A" + pr.top_freq_values = "| Mexico | 200\n| USA | 180\n| Canada | 120" + pr.top_patterns = "200 | Aaaaaa | 100 | AAA" + for k, v in overrides.items(): + setattr(pr, k, v) + return pr + + +def _mock_profiling_run_for_tg(tg_id): + pr = MagicMock() + pr.id = uuid4() + pr.table_groups_id = tg_id + pr.job_execution_id = uuid4() + return pr + + +@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.profiling.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_frequent_values_happy_path(mock_tg_cls, mock_pr_cls, mock_run_cls, db_session_mock): + tg = _mock_table_group() + mock_tg_cls.get.return_value = tg + mock_pr_cls.get_for_column.return_value = _mock_profile_result() + mock_run_cls.get.return_value = _mock_profiling_run_for_tg(tg.id) + + from testgen.mcp.tools.profiling import get_column_frequent_values + result = get_column_frequent_values(str(uuid4()), "customers", "country") + + assert "Frequent values: customers.country" in result + assert "Mexico" in result and "USA" in result and "Canada" in result + assert "40.00%" in result # 200/500 + assert "Top values" in result + + +@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.profiling.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_frequent_values_surfaces_job_execution_id_not_profile_run_id( + mock_tg_cls, mock_pr_cls, mock_run_cls, db_session_mock, +): + tg = _mock_table_group() + mock_tg_cls.get.return_value = tg + profile = _mock_profile_result() + mock_pr_cls.get_for_column.return_value = profile + run = _mock_profiling_run_for_tg(tg.id) + mock_run_cls.get.return_value = run + + from testgen.mcp.tools.profiling import get_column_frequent_values + result = get_column_frequent_values(str(uuid4()), "customers", "country") + + # The internal profile_run_id PK must not leak; only the job_execution_id is followable. + assert str(run.job_execution_id) in result + assert str(profile.profile_run_id) not in result + + +@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.profiling.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_frequent_values_pii_value_redacted_when_caller_lacks_view_pii( + mock_tg_cls, mock_pr_cls, mock_run_cls, db_session_mock, +): + tg = _mock_table_group(project_code="demo") + mock_tg_cls.get.return_value = tg + mock_pr_cls.get_for_column.return_value = _mock_profile_result( + pii_flag="B/CONTACT/Email", + top_freq_values="| alice@example.com | 5\n| bob@example.com | 3", + ) + mock_run_cls.get.return_value = _mock_profiling_run_for_tg(tg.id) + + # Default test conftest grants no view_pii (TEST_PERM_MATRIX has no entry). + from testgen.mcp.tools.profiling import get_column_frequent_values + result = get_column_frequent_values(str(uuid4()), "customers", "email") + + assert PII_REDACTED in result + assert "alice@example.com" not in result + + +@patch("testgen.mcp.permissions._compute_project_permissions") +@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.profiling.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_frequent_values_pii_value_visible_with_view_pii_grant( + mock_tg_cls, mock_pr_cls, mock_run_cls, mock_compute, db_session_mock, +): + tg = _mock_table_group(project_code="demo") + mock_tg_cls.get.return_value = tg + mock_pr_cls.get_for_column.return_value = _mock_profile_result( + pii_flag="B/CONTACT/Email", + top_freq_values="| alice@example.com | 5\n| bob@example.com | 3", + ) + mock_run_cls.get.return_value = _mock_profiling_run_for_tg(tg.id) + mock_compute.return_value = ProjectPermissions( + memberships={"demo": "role_a"}, + permission="catalog", + username="test_user", + ) + # Add view_pii to the matrix for this test by patching the role-lookup. + with patch("testgen.mcp.permissions.PluginHook") as hook_mock: + hook_mock.instance.return_value.rbac.get_roles_with_permission.return_value = ["role_a"] + from testgen.mcp.tools.profiling import get_column_frequent_values + result = get_column_frequent_values(str(uuid4()), "customers", "email") + + assert "alice@example.com" in result + assert PII_REDACTED not in result + + +@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.profiling.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_frequent_values_high_cardinality_fallback( + mock_tg_cls, mock_pr_cls, mock_run_cls, db_session_mock, +): + tg = _mock_table_group() + mock_tg_cls.get.return_value = tg + mock_pr_cls.get_for_column.return_value = _mock_profile_result( + top_freq_values=None, distinct_value_ct=10000, + ) + mock_run_cls.get.return_value = _mock_profiling_run_for_tg(tg.id) + + from testgen.mcp.tools.profiling import get_column_frequent_values + result = get_column_frequent_values(str(uuid4()), "customers", "customer_id") + + assert "Frequency data not available" in result + assert "10000" in result + + +@patch("testgen.mcp.tools.profiling.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_frequent_values_missing_profile_raises_not_accessible( + mock_tg_cls, mock_pr_cls, db_session_mock, +): + mock_tg_cls.get.return_value = _mock_table_group() + mock_pr_cls.get_for_column.return_value = None + + from testgen.mcp.tools.profiling import get_column_frequent_values + with pytest.raises(MCPResourceNotAccessible, match="Column profile"): + get_column_frequent_values(str(uuid4()), "customers", "ghost") + + +# ---------------------------------------------------------------------- +# get_column_patterns +# ---------------------------------------------------------------------- + + +@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.profiling.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_patterns_happy_path(mock_tg_cls, mock_pr_cls, mock_run_cls, db_session_mock): + tg = _mock_table_group() + mock_tg_cls.get.return_value = tg + mock_pr_cls.get_for_column.return_value = _mock_profile_result( + general_type="A", + top_patterns="326 | Aaaaaa | 176 | AAA", + ) + mock_run_cls.get.return_value = _mock_profiling_run_for_tg(tg.id) + + from testgen.mcp.tools.profiling import get_column_patterns + result = get_column_patterns(str(uuid4()), "customers", "country") + + assert "Character patterns: customers.country" in result + assert "Aaaaaa" in result and "AAA" in result + assert "Top patterns" in result + + +@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.profiling.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_patterns_non_string_column_fallback( + mock_tg_cls, mock_pr_cls, mock_run_cls, db_session_mock, +): + tg = _mock_table_group() + mock_tg_cls.get.return_value = tg + mock_pr_cls.get_for_column.return_value = _mock_profile_result( + general_type="N", + top_patterns=None, + ) + mock_run_cls.get.return_value = _mock_profiling_run_for_tg(tg.id) + + from testgen.mcp.tools.profiling import get_column_patterns + result = get_column_patterns(str(uuid4()), "products", "price") + + assert "column is not a string type" in result + + +@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.profiling.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_patterns_high_cardinality_fallback( + mock_tg_cls, mock_pr_cls, mock_run_cls, db_session_mock, +): + tg = _mock_table_group() + mock_tg_cls.get.return_value = tg + mock_pr_cls.get_for_column.return_value = _mock_profile_result( + general_type="A", + top_patterns=None, + distinct_value_ct=9999, + ) + mock_run_cls.get.return_value = _mock_profiling_run_for_tg(tg.id) + + from testgen.mcp.tools.profiling import get_column_patterns + result = get_column_patterns(str(uuid4()), "customers", "address") + + assert "Pattern data not available" in result + assert "9999" in result + + +# ---------------------------------------------------------------------- +# search_columns +# ---------------------------------------------------------------------- + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +def test_search_columns_no_scope_uses_all_accessible_projects(mock_dcc_cls, db_session_mock): + mock_dcc_cls.search_by_name.return_value = ([], 0) + mock_dcc_cls.summarize_matches_by_project.return_value = [] + + from testgen.mcp.tools.profiling import search_columns + result = search_columns("email") + + assert "all accessible projects" in result or "No columns matching" in result + + +@patch.object(DataColumnChars, "search_by_name") +@patch("testgen.mcp.tools.common.TableGroup") +def test_search_columns_table_group_scope_passes_tg_id_clause(mock_tg_cls, mock_method, db_session_mock): + tg = _mock_table_group() + mock_tg_cls.get.return_value = tg + mock_method.return_value = ([], 0) + + from testgen.mcp.tools.profiling import search_columns + search_columns("email", table_group_id=str(uuid4())) + + sql = " ".join( + str(c.compile(compile_kwargs={"literal_binds": True})) for c in mock_method.call_args[0] + ) + assert "table_groups_id" in sql + + +def test_search_columns_rejects_both_scopes_passed(db_session_mock): + from testgen.mcp.tools.profiling import search_columns + with pytest.raises(MCPUserError, match="not both"): + search_columns("email", project_code="demo", table_group_id=str(uuid4())) + + +def test_search_columns_empty_pattern_rejected(db_session_mock): + from testgen.mcp.tools.profiling import search_columns + with pytest.raises(MCPUserError, match="`pattern` is required"): + search_columns(" ") + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +def test_search_columns_renders_per_project_summary_when_no_scope(mock_dcc_cls, db_session_mock): + hit = MagicMock() + hit.project_code = "DEFAULT" + hit.table_groups_name = "default" + hit.schema_name = "demo" + hit.table_name = "d_ebike_suppliers" + hit.column_name = "contact_email" + mock_dcc_cls.search_by_name.return_value = ([hit], 1) + mock_dcc_cls.summarize_matches_by_project.return_value = [("DEFAULT", 1), ("DEMO_2", 0)] + + from testgen.mcp.tools.profiling import search_columns + result = search_columns("email") + + assert "Matches by project" in result + assert "DEFAULT" in result + + +@patch("testgen.mcp.tools.profiling.DataColumnChars") +@patch("testgen.mcp.tools.common.TableGroup") +def test_search_columns_table_group_scope_skips_per_project_summary( + mock_tg_cls, mock_dcc_cls, db_session_mock, +): + mock_tg_cls.get.return_value = _mock_table_group() + hit = MagicMock() + hit.project_code = "demo" + hit.table_groups_name = "default" + hit.schema_name = "demo" + hit.table_name = "customers" + hit.column_name = "email" + mock_dcc_cls.search_by_name.return_value = ([hit], 1) + + from testgen.mcp.tools.profiling import search_columns + result = search_columns("email", table_group_id=str(uuid4())) + + assert "Matches by project" not in result + mock_dcc_cls.summarize_matches_by_project.assert_not_called() From c44ec729c6df01d2296047d3a5c88de06e47abd0 Mon Sep 17 00:00:00 2001 From: Aarthy Adityan Date: Mon, 11 May 2026 16:40:27 -0400 Subject: [PATCH 10/58] fix(monitors): freshness-gate Volume_Trend/Metric_Trend prediction Stairstep volume/metric series (e.g. weekly-refreshed tables) collapsed the SARIMAX SE estimate, so every refresh tripped the band as a false positive. When the same table has an active Freshness_Trend monitor, prediction now fits SARIMAX on the value series filtered to fingerprint-change runs and emits a baseline. Execution dual-branches: band check when Freshness fired this run, `<> baseline` during stale periods (catches silent writes that the band check alone would miss). Falls back to a raw-history SARIMAX fit when the filtered fit cannot run. Drops the post-resample SARIMAX minimum to 8 and lifts the suite-level predict_min_lookback to TestThresholdsPrediction.run() so it gates the raw history once for every branch. Surfaces the gated baseline as "Threshold" on the sparkline tooltip alongside the lower/upper bound. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../commands/queries/execute_tests_query.py | 59 ++++- testgen/commands/run_test_execution.py | 7 +- .../commands/test_thresholds_prediction.py | 110 +++++++-- testgen/common/freshness_service.py | 19 ++ testgen/common/time_series_service.py | 2 +- .../get_current_freshness_signal.sql | 12 + .../get_freshness_fingerprint_events.sql | 18 ++ .../get_historical_test_results.sql | 3 + .../js/pages/table_monitoring_trends.js | 5 + .../js/components/monitoring_sparkline.js | 4 + testgen/ui/views/monitors_dashboard.py | 1 + .../queries/test_execute_tests_query.py | 119 +++++++++ .../commands/test_thresholds_prediction.py | 225 +++++++++++++++++- tests/unit/common/conftest.py | 15 +- tests/unit/common/test_freshness_scenarios.py | 6 +- tests/unit/common/test_freshness_service.py | 35 +++ tests/unit/common/test_time_series_service.py | 17 -- 17 files changed, 599 insertions(+), 58 deletions(-) create mode 100644 testgen/template/execution/get_current_freshness_signal.sql create mode 100644 testgen/template/prediction/get_freshness_fingerprint_events.sql diff --git a/testgen/commands/queries/execute_tests_query.py b/testgen/commands/queries/execute_tests_query.py index e81d95a9..794ffcdb 100644 --- a/testgen/commands/queries/execute_tests_query.py +++ b/testgen/commands/queries/execute_tests_query.py @@ -8,9 +8,15 @@ from testgen.common import read_template_sql_file from testgen.common.clean_sql import concat_columns -from testgen.common.database.database_service import get_flavor_service, get_tg_schema, replace_params +from testgen.common.database.database_service import ( + fetch_dict_from_db, + get_flavor_service, + get_tg_schema, + replace_params, +) from testgen.common.freshness_service import ( count_excluded_minutes, + get_freshness_gated_baseline, get_schedule_params, is_excluded_day, resolve_holiday_dates, @@ -264,6 +270,52 @@ def __init__(self, connection: Connection, table_group: TableGroup, test_suite: test_suite.holiday_codes_list, pd.DatetimeIndex([datetime(self.run_date.year - 1, 1, 1), datetime(self.run_date.year + 1, 12, 31)]), ) + # Cache of (schema, table) -> "did Freshness_Trend detect a fingerprint change in + # this run?". True / False / None (no Freshness_Trend result). Populated lazily + # per table; reused across all Volume/Metric defs for the same table. + self._freshness_changed_cache: dict[tuple[str, str], bool | None] = {} + + def _freshness_changed_for_table(self, test_def: TestExecutionDef) -> bool | None: + """Did Freshness_Trend detect a fingerprint change for the test's table in this run? + + Reads the latest Freshness_Trend result_signal written during the current run. + Freshness_Trend emits `result_signal = '0'` when the table fingerprint differs + from the previous run's baseline (i.e., an update was detected). Any other value + (the interval since last update) means no change. + + Returns True / False per the signal, or None if no Freshness_Trend result exists + for this table in this run. + """ + cache_key = (test_def.schema_name, test_def.table_name) + if cache_key in self._freshness_changed_cache: + return self._freshness_changed_cache[cache_key] + + rows = fetch_dict_from_db(*self._get_query("get_current_freshness_signal.sql", test_def=test_def)) + changed: bool | None = None + if rows and rows[0].get("result_signal") is not None: + changed = str(rows[0]["result_signal"]) == "0" + self._freshness_changed_cache[cache_key] = changed + return changed + + def _resolve_cat_operator_and_condition(self, test_def: TestExecutionDef) -> tuple[str, str]: + """Pick the operator / condition pair to feed into build_cat_expressions. + + For Volume_Trend / Metric_Trend with freshness-gating enabled, when Freshness_Trend + detected no change in this run the table is in a "stale period" — the measure must + equal baseline_value, and any deviation is a silent-write anomaly. In that case, + override the test definition's `NOT BETWEEN` band check with a strict equality + check against baseline_value. All other cases (band check, refresh detected, + non-monitor test types, or no Freshness_Trend result) keep the test definition's + own operator and condition. + """ + if ( + test_def.test_type in ("Volume_Trend", "Metric_Trend") + and (baseline := get_freshness_gated_baseline(test_def.prediction)) is not None + and self._freshness_changed_for_table(test_def) is False + ): + return "<>", str(baseline) + + return test_def.test_operator, test_def.test_condition def _get_input_parameters(self, test_def: TestExecutionDef) -> str: return "; ".join( @@ -464,15 +516,16 @@ def aggregate_cat_tests( # Don't recalculate expressions if it was already done before if not td.measure_expression or not td.condition_expression: params = self._get_params(td) + operator, condition_template = self._resolve_cat_operator_and_condition(td) measure = replace_params(td.measure, params) measure = replace_templated_functions(measure, self.flavor) - condition = replace_params(td.test_condition, params) + condition = replace_params(condition_template, params) condition = replace_templated_functions(condition, self.flavor) td.measure_expression, td.condition_expression = build_cat_expressions( measure=measure, - test_operator=td.test_operator, + test_operator=operator, test_condition=condition, history_calculation=td.history_calculation, lower_tolerance=td.lower_tolerance, diff --git a/testgen/commands/run_test_execution.py b/testgen/commands/run_test_execution.py index 568a463d..65d759df 100644 --- a/testgen/commands/run_test_execution.py +++ b/testgen/commands/run_test_execution.py @@ -110,7 +110,12 @@ def run_test_execution( "METADATA": partial(_run_tests, sql_generator, "METADATA"), "CAT": partial(_run_cat_tests, sql_generator), } - # Run metadata tests last so that results for other tests are available to them + # Run order: QUERY → CAT → METADATA is load-bearing for monitor suites. + # Freshness_Trend (QUERY) writes the table fingerprint to test_results, which + # Volume_Trend and Metric_Trend (both CAT) read at execution time to apply + # freshness-gated thresholds (see TestExecutionSQL._get_params). Metadata tests + # stay last so results for other tests are available to them. Do not reorder + # without revisiting freshness-gating in the SQL templates and exec params. for run_type in ["QUERY", "CAT", "METADATA"]: if (run_test_defs := [td for td in valid_test_defs if td.run_type == run_type]): run_functions[run_type](run_test_defs, save_progress=not test_suite.is_monitor) diff --git a/testgen/commands/test_thresholds_prediction.py b/testgen/commands/test_thresholds_prediction.py index 7f6617ee..7c501e98 100644 --- a/testgen/commands/test_thresholds_prediction.py +++ b/testgen/commands/test_thresholds_prediction.py @@ -87,11 +87,18 @@ def run(self) -> None: df = to_dataframe(test_results, coerce_float=True) grouped_dfs = df.groupby("test_definition_id", group_keys=False) + # Freshness update events are fetched as secondary data only when the suite + # is a monitor — Volume_Trend / Metric_Trend in monitor suites couple to the + # Freshness_Trend signal to avoid stairstep false positives. + freshness_updates_by_table: dict[tuple[str, str], list[str]] = ( + self._fetch_freshness_updates_by_table() if self.test_suite.is_monitor else {} + ) + LOG.info(f"Training prediction models for tests: {len(grouped_dfs)}") prediction_results = [] for test_def_id, group in grouped_dfs: test_type = group["test_type"].iloc[0] - history = group[["test_time", "result_signal"]] + history = group[["test_time", "result_signal", "test_run_id"]] history = history.set_index("test_time") test_prediction = [ @@ -99,30 +106,42 @@ def run(self) -> None: test_def_id, to_sql_timestamp(self.run_date), ] - if test_type == "Freshness_Trend": + # Skip prediction if history is smaller than configured lookback + if len(history) < (self.test_suite.predict_min_lookback or 1): + test_prediction.extend([None, None, None, None]) + elif test_type == "Freshness_Trend": lower, upper, staleness, prediction = compute_freshness_threshold( history, sensitivity=self.test_suite.predict_sensitivity or PredictSensitivity.medium, - min_lookback=self.test_suite.predict_min_lookback or 1, exclude_weekends=self.test_suite.predict_exclude_weekends, holiday_codes=self.test_suite.holiday_codes_list, schedule_tz=self.tz, ) test_prediction.extend([lower, upper, staleness, prediction]) - else: - lower, upper, prediction = compute_sarimax_threshold( + elif test_type in ("Volume_Trend", "Metric_Trend"): + table_key = (group["schema_name"].iloc[0], group["table_name"].iloc[0]) + lower, upper, baseline, prediction = compute_volume_or_metric_threshold( history, + freshness_updates=freshness_updates_by_table.get(table_key, []), sensitivity=self.test_suite.predict_sensitivity or PredictSensitivity.medium, - min_lookback=self.test_suite.predict_min_lookback or 1, exclude_weekends=self.test_suite.predict_exclude_weekends, holiday_codes=self.test_suite.holiday_codes_list, schedule_tz=self.tz, ) if test_type == "Volume_Trend": - if lower is not None: + if lower is not None: lower = max(lower, 0.0) if upper is not None: upper = max(upper, 0.0) + test_prediction.extend([lower, upper, baseline, prediction]) + else: + lower, upper, prediction = compute_sarimax_threshold( + history, + sensitivity=self.test_suite.predict_sensitivity or PredictSensitivity.medium, + exclude_weekends=self.test_suite.predict_exclude_weekends, + holiday_codes=self.test_suite.holiday_codes_list, + schedule_tz=self.tz, + ) test_prediction.extend([lower, upper, None, prediction]) prediction_results.append(test_prediction) @@ -149,11 +168,21 @@ def _get_query( query = replace_params(query, params) return query, params + def _fetch_freshness_updates_by_table( + self, + ) -> dict[tuple[str, str], list[str]]: + """Fetch test_run_ids of Freshness_Trend fingerprint changes, indexed by table.""" + rows = fetch_dict_from_db(*self._get_query("get_freshness_fingerprint_events.sql")) + events_by_table: dict[tuple[str, str], list[str]] = {} + for row in rows: + key = (row["schema_name"], row["table_name"]) + events_by_table.setdefault(key, []).append(str(row["test_run_id"])) + return events_by_table + def compute_freshness_threshold( history: pd.DataFrame, sensitivity: PredictSensitivity, - min_lookback: int = 1, exclude_weekends: bool = False, holiday_codes: list[str] | None = None, schedule_tz: str | None = None, @@ -163,9 +192,6 @@ def compute_freshness_threshold( Returns (lower, upper, staleness_threshold, prediction_json) in business minutes, or (None, None, None, None) if not enough data. """ - if len(history) < min_lookback: - return None, None, None, None - upper_percentile, floor_multiplier, lower_percentile = FRESHNESS_THRESHOLD_MAP[sensitivity] staleness_factor = STALENESS_FACTOR_MAP[sensitivity] @@ -264,7 +290,6 @@ def compute_sarimax_threshold( history: pd.DataFrame, sensitivity: PredictSensitivity, num_forecast: int = NUM_FORECAST, - min_lookback: int = 1, exclude_weekends: bool = False, holiday_codes: list[str] | None = None, schedule_tz: str | None = None, @@ -273,12 +298,9 @@ def compute_sarimax_threshold( Returns (lower, upper, forecast_json) or (None, None, None) if insufficient data. """ - if len(history) < min_lookback: - return None, None, None - try: forecast = get_sarimax_forecast( - history, + history[["result_signal"]], # SARIMAX only consumes result_signal - drop other columns num_forecast=num_forecast, exclude_weekends=exclude_weekends, holiday_codes=holiday_codes, @@ -305,3 +327,59 @@ def compute_sarimax_threshold( return float(lower_tolerance), float(upper_tolerance), forecast.to_json() except NotEnoughData: return None, None, None + + +def compute_volume_or_metric_threshold( + history: pd.DataFrame, + freshness_updates: list[str], + sensitivity: PredictSensitivity, + num_forecast: int = NUM_FORECAST, + exclude_weekends: bool = False, + holiday_codes: list[str] | None = None, + schedule_tz: str | None = None, +) -> tuple[float | None, float | None, float | None, str | None]: + """SARIMAX threshold for Volume_Trend / Metric_Trend with freshness-gating. + + First, attempts a SARIMAX fit on the value series filtered only to points with freshness updates. + This avoids the "stairstep" false-positive shape where inter-change plateaus collapse the SE estimate. + The returned prediction JSON is augmented with `freshness_gated` and `baseline_value` so + that test execution can apply dual-branch evaluation. + + If the filtered fit fails for any reason, falls back to fit SARIMAX on + the raw value series and emits a prediction JSON without the freshness-gating markers. + + `history` is expected to have a `test_run_id` column alongside `result_signal`, and to be + indexed by `test_time`. `freshness_updates` is the list of run identifiers where + Freshness_Trend detected a fingerprint change. + """ + filtered_history = history.loc[history["test_run_id"].astype(str).isin(freshness_updates)] + lower, upper, prediction = compute_sarimax_threshold( + filtered_history, + sensitivity=sensitivity, + num_forecast=num_forecast, + exclude_weekends=exclude_weekends, + holiday_codes=holiday_codes, + schedule_tz=schedule_tz, + ) + if prediction is not None: + # Pull the baseline value from the most-recent filtered row. + last_update_ts = filtered_history.index.max() + baseline_value = filtered_history.loc[last_update_ts, "result_signal"] + baseline_value = float(baseline_value) if not pd.isna(baseline_value) else None + prediction_dict = json.loads(prediction) + prediction_dict.update({ + "freshness_gated": True, + "baseline_value": baseline_value, + }) + prediction = json.dumps(prediction_dict) + return lower, upper, baseline_value, prediction + + lower, upper, prediction = compute_sarimax_threshold( + history, + sensitivity=sensitivity, + num_forecast=num_forecast, + exclude_weekends=exclude_weekends, + holiday_codes=holiday_codes, + schedule_tz=schedule_tz, + ) + return lower, upper, None, prediction diff --git a/testgen/common/freshness_service.py b/testgen/common/freshness_service.py index f7810787..b87e5a93 100644 --- a/testgen/common/freshness_service.py +++ b/testgen/common/freshness_service.py @@ -143,6 +143,25 @@ def get_schedule_params(prediction: dict | str | None) -> ScheduleParams: return ScheduleParams(excluded_days=excluded_days, window_start=window_start, window_end=window_end) +def get_freshness_gated_baseline(prediction: dict | str | None) -> float | None: + """Extract the freshness-gated baseline value from a Volume_Trend / Metric_Trend + prediction JSON. + + The baseline is the test value at the most recent detected freshness update. Returns + None when the prediction is missing, empty, does not have freshness-gating enabled, + or has no baseline value recorded. + """ + if not prediction: + return None + parsed = prediction if isinstance(prediction, dict) else json.loads(prediction) + if not parsed.get("freshness_gated"): + return None + baseline_value = parsed.get("baseline_value") + if baseline_value is None: + return None + return float(baseline_value) + + def is_excluded_day( dt: pd.Timestamp, exclude_weekends: bool, diff --git a/testgen/common/time_series_service.py b/testgen/common/time_series_service.py index 7aca697a..aeabb180 100644 --- a/testgen/common/time_series_service.py +++ b/testgen/common/time_series_service.py @@ -10,7 +10,7 @@ # This is a heuristic minimum to get a reasonable prediction # Not a hard limit of the model -MIN_TRAIN_VALUES = 20 +MIN_TRAIN_VALUES = 8 class NotEnoughData(ValueError): diff --git a/testgen/template/execution/get_current_freshness_signal.sql b/testgen/template/execution/get_current_freshness_signal.sql new file mode 100644 index 00000000..962d6f5c --- /dev/null +++ b/testgen/template/execution/get_current_freshness_signal.sql @@ -0,0 +1,12 @@ +-- Latest Freshness_Trend result_signal for a given table within the current run. Used +-- by Volume_Trend / Metric_Trend execution to detect whether the table has been updated +-- this run: result_signal = '0' means fingerprint changed, any other value means no +-- change (signal carries the interval-since-last-update). +SELECT result_signal +FROM test_results +WHERE test_run_id = :TEST_RUN_ID ::UUID + AND test_type = 'Freshness_Trend' + AND schema_name = :SCHEMA_NAME + AND table_name = :TABLE_NAME +ORDER BY test_time DESC +LIMIT 1; diff --git a/testgen/template/prediction/get_freshness_fingerprint_events.sql b/testgen/template/prediction/get_freshness_fingerprint_events.sql new file mode 100644 index 00000000..ad892d7e --- /dev/null +++ b/testgen/template/prediction/get_freshness_fingerprint_events.sql @@ -0,0 +1,18 @@ +-- Fingerprint-change events from Freshness_Trend tests, used as secondary data for +-- freshness-gated SARIMAX prediction of Volume_Trend / Metric_Trend. +-- +-- Returns one row per detected fingerprint change (result_signal = '0'), ordered by +-- (schema, table, time). +SELECT DISTINCT + d.schema_name, + d.table_name, + r.test_run_id, + r.test_time +FROM test_results r +JOIN test_definitions d ON d.id = r.test_definition_id +WHERE r.test_suite_id = :TEST_SUITE_ID + AND d.test_suite_id = :TEST_SUITE_ID + AND d.test_type = 'Freshness_Trend' + AND d.test_active = 'Y' + AND r.result_signal = '0' +ORDER BY d.schema_name, d.table_name, r.test_time; diff --git a/testgen/template/prediction/get_historical_test_results.sql b/testgen/template/prediction/get_historical_test_results.sql index 800ecc10..cbf91a32 100644 --- a/testgen/template/prediction/get_historical_test_results.sql +++ b/testgen/template/prediction/get_historical_test_results.sql @@ -12,7 +12,10 @@ WITH filtered_defs AS ( AND history_calculation = 'PREDICT' ) SELECT r.test_definition_id, + r.test_run_id, d.test_type, + d.schema_name, + d.table_name, r.test_time, CASE WHEN r.result_signal ~ '^-?[0-9]*\.?[0-9]+$' THEN r.result_signal::NUMERIC diff --git a/testgen/ui/components/frontend/js/pages/table_monitoring_trends.js b/testgen/ui/components/frontend/js/pages/table_monitoring_trends.js index 8aa0891f..8e0c86cd 100644 --- a/testgen/ui/components/frontend/js/pages/table_monitoring_trends.js +++ b/testgen/ui/components/frontend/js/pages/table_monitoring_trends.js @@ -455,6 +455,10 @@ const ChartsSection = (props, { schemaChartSelection, getDataStructureLogs }) => originalUpperTolerance: e.upper_tolerance != undefined ? parseInt(e.upper_tolerance) : undefined, + // Freshness-gated baseline (only present on gated runs). + originalThreshold: e.threshold_value != undefined + ? parseFloat(e.threshold_value) + : undefined, label: 'Row count', isAnomaly: e.is_anomaly, isTraining: e.is_training, @@ -490,6 +494,7 @@ const ChartsSection = (props, { schemaChartSelection, getDataStructureLogs }) => originalY: e.value, originalLowerTolerance: e.lower_tolerance, originalUpperTolerance: e.upper_tolerance, + originalThreshold: e.threshold_value, isAnomaly: e.is_anomaly, isTraining: e.is_training, isPending: e.is_pending, diff --git a/testgen/ui/static/js/components/monitoring_sparkline.js b/testgen/ui/static/js/components/monitoring_sparkline.js index f81251e9..4eb2c61a 100644 --- a/testgen/ui/static/js/components/monitoring_sparkline.js +++ b/testgen/ui/static/js/components/monitoring_sparkline.js @@ -24,6 +24,7 @@ * @property {boolean?} isPending * @property {number?} lowerTolerance * @property {number?} upperTolerance + * @property {number?} originalThreshold * * @typedef PredictionPoint * @type {Object} @@ -253,6 +254,9 @@ const MonitoringSparklineChartTooltip = (point) => { {class: 'flex-column'}, span({class: 'text-left mb-1'}, formatTimestamp(point.originalX)), span({class: 'text-left text-small'}, `${point.label || 'Value'}: ${formatNumber(point.originalY)}`), + point.originalThreshold != undefined + ? span({class: 'text-left text-small'}, `Baseline: ${formatNumber(point.originalThreshold)}`) + : '', point.lowerTolerance != undefined ? span({class: 'text-left text-small'}, `Lower bound: ${formatNumber(point.originalLowerTolerance)}`) : '', diff --git a/testgen/ui/views/monitors_dashboard.py b/testgen/ui/views/monitors_dashboard.py index 3007c173..5b33f7b0 100644 --- a/testgen/ui/views/monitors_dashboard.py +++ b/testgen/ui/views/monitors_dashboard.py @@ -946,6 +946,7 @@ def get_monitor_events_for_table(test_suite_id: str, table_name: str, lookback_m "is_pending": not bool(event["result_id"]), "lower_tolerance": params.get("lower_tolerance") if params.get("lower_tolerance") else None, "upper_tolerance": params.get("upper_tolerance") if params.get("upper_tolerance") else None, + "threshold_value": params.get("threshold_value") if params.get("threshold_value") else None, }) return { diff --git a/tests/unit/commands/queries/test_execute_tests_query.py b/tests/unit/commands/queries/test_execute_tests_query.py index 71fb66dd..4839f99b 100644 --- a/tests/unit/commands/queries/test_execute_tests_query.py +++ b/tests/unit/commands/queries/test_execute_tests_query.py @@ -1,10 +1,12 @@ from datetime import UTC, datetime +from unittest.mock import patch from uuid import uuid4 import pytest from testgen.commands.queries.execute_tests_query import ( TestExecutionDef, + TestExecutionSQL, build_cat_expressions, group_cat_tests, parse_cat_results, @@ -359,3 +361,120 @@ def test_parse_result_code_negative_one(): rows = parse_cat_results(results, test_defs, uuid4(), uuid4(), datetime.now(UTC), _make_input_params_fn()) assert rows[0][10] == "-1" + + +# --- TestExecutionSQL freshness-gating helpers --- + + +def _make_execution_sql() -> TestExecutionSQL: + """Build a minimal TestExecutionSQL instance for testing instance methods. + + Bypasses __init__ (which hits the database) and sets only the attributes the + freshness-gating methods touch. + """ + instance = TestExecutionSQL.__new__(TestExecutionSQL) + instance._freshness_changed_cache = {} + return instance + + +FRESHNESS_FETCH_TARGET = "testgen.commands.queries.execute_tests_query.fetch_dict_from_db" + + +@patch.object(TestExecutionSQL, "_get_query", return_value=("SELECT ...", {})) +@patch(FRESHNESS_FETCH_TARGET) +def test_freshness_changed_true_when_result_signal_is_zero(mock_fetch, _mock_query): + mock_fetch.return_value = [{"result_signal": "0"}] + instance = _make_execution_sql() + assert instance._freshness_changed_for_table(_make_td()) is True + + +@patch.object(TestExecutionSQL, "_get_query", return_value=("SELECT ...", {})) +@patch(FRESHNESS_FETCH_TARGET) +def test_freshness_changed_false_when_result_signal_is_interval(mock_fetch, _mock_query): + mock_fetch.return_value = [{"result_signal": "1440"}] + instance = _make_execution_sql() + assert instance._freshness_changed_for_table(_make_td()) is False + + +@patch.object(TestExecutionSQL, "_get_query", return_value=("SELECT ...", {})) +@patch(FRESHNESS_FETCH_TARGET) +def test_freshness_changed_none_when_no_result(mock_fetch, _mock_query): + mock_fetch.return_value = [] + instance = _make_execution_sql() + assert instance._freshness_changed_for_table(_make_td()) is None + + +@patch.object(TestExecutionSQL, "_get_query", return_value=("SELECT ...", {})) +@patch(FRESHNESS_FETCH_TARGET) +def test_freshness_changed_cached_per_table(mock_fetch, _mock_query): + """Multiple Volume/Metric defs on the same table should not re-query.""" + mock_fetch.return_value = [{"result_signal": "0"}] + instance = _make_execution_sql() + instance._freshness_changed_for_table(_make_td(schema_name="s", table_name="t")) + instance._freshness_changed_for_table(_make_td(schema_name="s", table_name="t")) + assert mock_fetch.call_count == 1 + + +def test_resolve_cat_returns_definition_default_for_non_monitor_types(): + instance = _make_execution_sql() + td = _make_td(test_type="Alpha_Trunc", test_operator=">=", test_condition="50") + operator, condition = instance._resolve_cat_operator_and_condition(td) + assert (operator, condition) == (">=", "50") + + +def test_resolve_cat_returns_definition_default_when_no_gating(): + """Volume_Trend / Metric_Trend with no freshness_gated flag in prediction → band check.""" + instance = _make_execution_sql() + td = _make_td( + test_type="Volume_Trend", + test_operator="NOT BETWEEN", + test_condition="{LOWER_TOLERANCE} AND {UPPER_TOLERANCE}", + prediction={"mean": {"123": 220.0}}, # no freshness_gated + ) + operator, condition = instance._resolve_cat_operator_and_condition(td) + assert operator == "NOT BETWEEN" + assert condition == "{LOWER_TOLERANCE} AND {UPPER_TOLERANCE}" + + +@patch.object(TestExecutionSQL, "_freshness_changed_for_table", return_value=False) +def test_resolve_cat_stale_period_overrides_to_baseline_equality(_mock_changed): + """When freshness-gated and Freshness signal != '0' (no change), override to <> baseline.""" + instance = _make_execution_sql() + td = _make_td( + test_type="Volume_Trend", + test_operator="NOT BETWEEN", + test_condition="{LOWER_TOLERANCE} AND {UPPER_TOLERANCE}", + prediction={"freshness_gated": True, "baseline_value": 220.0}, + ) + assert instance._resolve_cat_operator_and_condition(td) == ("<>", "220.0") + + +@patch.object(TestExecutionSQL, "_freshness_changed_for_table", return_value=True) +def test_resolve_cat_refresh_period_uses_band_check(_mock_changed): + """When freshness-gated and Freshness fired this run, fall through to band check.""" + instance = _make_execution_sql() + td = _make_td( + test_type="Volume_Trend", + test_operator="NOT BETWEEN", + test_condition="{LOWER_TOLERANCE} AND {UPPER_TOLERANCE}", + prediction={"freshness_gated": True, "baseline_value": 220.0}, + ) + operator, condition = instance._resolve_cat_operator_and_condition(td) + assert operator == "NOT BETWEEN" + assert condition == "{LOWER_TOLERANCE} AND {UPPER_TOLERANCE}" + + +@patch.object(TestExecutionSQL, "_freshness_changed_for_table", return_value=None) +def test_resolve_cat_no_freshness_result_uses_band_check(_mock_changed): + """When no Freshness_Trend has run for this table this run, fall back to band check.""" + instance = _make_execution_sql() + td = _make_td( + test_type="Metric_Trend", + test_operator="NOT BETWEEN", + test_condition="{LOWER_TOLERANCE} AND {UPPER_TOLERANCE}", + prediction={"freshness_gated": True, "baseline_value": 5.5}, + ) + operator, condition = instance._resolve_cat_operator_and_condition(td) + assert operator == "NOT BETWEEN" + + diff --git a/tests/unit/commands/test_thresholds_prediction.py b/tests/unit/commands/test_thresholds_prediction.py index f9df4592..37891a8c 100644 --- a/tests/unit/commands/test_thresholds_prediction.py +++ b/tests/unit/commands/test_thresholds_prediction.py @@ -1,5 +1,6 @@ import json -from unittest.mock import patch +from datetime import datetime +from unittest.mock import MagicMock, patch import pandas as pd import pytest @@ -8,7 +9,9 @@ from testgen.commands.test_thresholds_prediction import ( T_DISTRIBUTION_THRESHOLD, Z_SCORE_MAP, + TestThresholdsPrediction, compute_sarimax_threshold, + compute_volume_or_metric_threshold, ) from testgen.common.models.test_suite import PredictSensitivity from testgen.common.time_series_service import NotEnoughData @@ -16,6 +19,19 @@ pytestmark = pytest.mark.unit +def _make_prediction_instance(suite_id: str = "suite-xyz") -> TestThresholdsPrediction: + """Build a minimal TestThresholdsPrediction instance for testing instance methods. + + Bypasses __init__ (which queries the database) and sets just the attributes that + _get_query and methods under test rely on. + """ + instance = TestThresholdsPrediction.__new__(TestThresholdsPrediction) + instance.test_suite = MagicMock(id=suite_id) + instance.run_date = datetime(2026, 1, 1) + instance.tz = None + return instance + + def _make_history(n: int, value: float = 100.0) -> pd.DataFrame: """Build a minimal history DataFrame with n data points.""" dates = pd.date_range("2025-01-01", periods=n, freq="D") @@ -31,17 +47,6 @@ def _make_forecast(mean_values: list[float], se_values: list[float]) -> pd.DataF MOCK_TARGET = "testgen.commands.test_thresholds_prediction.get_sarimax_forecast" -# --- min_lookback guard --- - - -def test_below_min_lookback_returns_none(): - history = _make_history(3) - lower, upper, prediction = compute_sarimax_threshold(history, PredictSensitivity.medium, min_lookback=5) - assert lower is None - assert upper is None - assert prediction is None - - # --- Normal tolerance calculation (large sample, z-scores used directly) --- @@ -196,3 +201,199 @@ def test_all_z_score_columns_added_to_forecast(mock_forecast): for key in Z_SCORE_MAP: col = f"{key[0]}|{key[1].value}" assert col in forecast.columns + + +# --- TestThresholdsPrediction._fetch_freshness_updates_by_table --- +# +# Method fetches via _get_query → get_freshness_fingerprint_events.sql, which returns +# rows pre-filtered to fingerprint-change events and ordered by (schema, table, time). +# Tests mock the fetch and verify the indexing. + +FETCH_TARGET = "testgen.commands.test_thresholds_prediction.fetch_dict_from_db" + + +@patch(FETCH_TARGET) +def test_fetch_freshness_events_groups_by_table(mock_fetch): + mock_fetch.return_value = [ + {"schema_name": "s", "table_name": "t1", "test_run_id": "run_1"}, + {"schema_name": "s", "table_name": "t1", "test_run_id": "run_2"}, + {"schema_name": "s", "table_name": "t2", "test_run_id": "run_3"}, + ] + instance = _make_prediction_instance() + events = instance._fetch_freshness_updates_by_table() + assert set(events.keys()) == {("s", "t1"), ("s", "t2")} + assert events[("s", "t1")] == ["run_1", "run_2"] + assert events[("s", "t2")] == ["run_3"] + + +@patch(FETCH_TARGET) +def test_fetch_freshness_events_preserves_input_order(mock_fetch): + """SQL returns rows ordered by (schema, table, test_time); the method trusts that + order rather than re-sorting.""" + mock_fetch.return_value = [ + {"schema_name": "s", "table_name": "t", "test_run_id": "run_a"}, + {"schema_name": "s", "table_name": "t", "test_run_id": "run_b"}, + {"schema_name": "s", "table_name": "t", "test_run_id": "run_c"}, + ] + instance = _make_prediction_instance() + events = instance._fetch_freshness_updates_by_table() + assert events[("s", "t")] == ["run_a", "run_b", "run_c"] + + +@patch(FETCH_TARGET) +def test_fetch_freshness_events_coerces_run_id_to_str(mock_fetch): + """test_run_id can come back as a UUID object — must be cast to str for downstream + .isin() matching against the str-cast Volume/Metric test_run_id column.""" + from uuid import UUID as _UUID + rid = _UUID("12345678-1234-5678-1234-567812345678") + mock_fetch.return_value = [ + {"schema_name": "s", "table_name": "t", "test_run_id": rid}, + ] + instance = _make_prediction_instance() + events = instance._fetch_freshness_updates_by_table() + assert events[("s", "t")] == [str(rid)] + + +@patch(FETCH_TARGET) +def test_fetch_freshness_events_empty_result(mock_fetch): + mock_fetch.return_value = [] + instance = _make_prediction_instance() + assert instance._fetch_freshness_updates_by_table() == {} + + +@patch(FETCH_TARGET) +def test_fetch_freshness_events_passes_suite_id_through_get_query(mock_fetch): + """Reuses self._get_query, which substitutes TEST_SUITE_ID from self.test_suite.id.""" + mock_fetch.return_value = [] + instance = _make_prediction_instance(suite_id="suite-xyz") + instance._fetch_freshness_updates_by_table() + _query, params = mock_fetch.call_args.args + assert params["TEST_SUITE_ID"] == "suite-xyz" + + +# --- compute_volume_or_metric_threshold --- + + +def _history_with_run_ids(timestamps: list[str], run_ids: list[str], value: float = 100.0) -> pd.DataFrame: + """Build a Volume/Metric-shaped history: indexed by test_time, with a test_run_id + column matching how `run()` slices the historical-results dataframe per definition.""" + assert len(timestamps) == len(run_ids) + return pd.DataFrame( + {"result_signal": [value] * len(timestamps), "test_run_id": run_ids}, + index=pd.to_datetime(timestamps), + ) + + +@patch(MOCK_TARGET) +def test_freshness_gating_engages_when_filtered_fit_succeeds(mock_forecast): + mock_forecast.return_value = _make_forecast([220.0], [1.0]) + timestamps = [f"2026-01-{day:02d}" for day in range(1, 21)] + run_ids = [f"run_{i:02d}" for i in range(len(timestamps))] + history = _history_with_run_ids(timestamps, run_ids, value=220.0) + freshness_updates = run_ids[:8] + + lower, upper, baseline, prediction = compute_volume_or_metric_threshold( + history, freshness_updates, PredictSensitivity.medium, + ) + + assert lower is not None and upper is not None + assert baseline == 220.0 + assert prediction is not None + parsed = json.loads(prediction) + assert parsed["freshness_gated"] is True + assert parsed["baseline_value"] == 220.0 + + +@patch(MOCK_TARGET) +def test_freshness_gating_falls_back_when_filtered_fit_raises(mock_forecast): + """If SARIMAX fails on the freshness-filtered series (NotEnoughData after resample, + convergence), fall back to fitting on the raw value series and emit a prediction + without the freshness-gating markers.""" + raw_forecast = _make_forecast([220.0], [1.0]) + mock_forecast.side_effect = [NotEnoughData("not enough"), raw_forecast] + timestamps = [f"2026-01-{day:02d}" for day in range(1, 21)] + run_ids = [f"run_{i:02d}" for i in range(len(timestamps))] + history = _history_with_run_ids(timestamps, run_ids, value=220.0) + freshness_updates = run_ids[:5] # any selection — first call is forced to raise + + _, _, baseline, prediction = compute_volume_or_metric_threshold( + history, freshness_updates, PredictSensitivity.medium, + ) + + assert mock_forecast.call_count == 2 # filtered failed, raw retried + assert baseline is None + assert prediction is not None + parsed = json.loads(prediction) + assert "freshness_gated" not in parsed + assert "baseline_value" not in parsed + + +@patch(MOCK_TARGET) +def test_freshness_gating_falls_back_when_no_freshness_events(mock_forecast): + """Empty freshness_updates → filtered history is empty → filtered fit fails → + fall back to fitting on the raw series.""" + # First call (filtered, 0 rows) returns enough that compute_sarimax_threshold trips + # the NaN tolerance path; second call (raw) succeeds. + raw_forecast = _make_forecast([220.0], [1.0]) + mock_forecast.side_effect = [NotEnoughData("not enough"), raw_forecast] + timestamps = [f"2026-01-{day:02d}" for day in range(1, 21)] + run_ids = [f"run_{i:02d}" for i in range(len(timestamps))] + history = _history_with_run_ids(timestamps, run_ids) + + _, _, baseline, prediction = compute_volume_or_metric_threshold( + history, freshness_updates=[], sensitivity=PredictSensitivity.medium, + ) + + assert baseline is None + assert prediction is not None + parsed = json.loads(prediction) + assert "freshness_gated" not in parsed + + +@patch(MOCK_TARGET) +def test_freshness_gating_fits_on_filtered_series(mock_forecast): + """SARIMAX should be fit on the filtered series (one row per freshness change), + not on the raw plateau-laden series. Verified via the length of the dataframe + passed to get_sarimax_forecast on the engaging call.""" + mock_forecast.return_value = _make_forecast([220.0], [1.0]) + timestamps = [f"2026-01-{day:02d}" for day in range(1, 21)] + run_ids = [f"run_{i:02d}" for i in range(len(timestamps))] + history = _history_with_run_ids(timestamps, run_ids, value=220.0) + freshness_updates = run_ids[:8] + + compute_volume_or_metric_threshold( + history, freshness_updates, PredictSensitivity.medium, + ) + + fitted_history = mock_forecast.call_args.args[0] + assert len(fitted_history) == len(freshness_updates) + + +@patch(MOCK_TARGET) +def test_freshness_gating_baseline_from_filtered_when_events_extend_past_history(mock_forecast): + """When freshness_updates includes runs beyond the (retention-trimmed) history window, + baseline_value must come from the most recent filtered row — not from a run that's no + longer in history.""" + mock_forecast.return_value = _make_forecast([220.0], [1.0]) + # History only covers the first 8 days (run_00..run_07) + history_timestamps = [f"2026-01-{day:02d}" for day in range(1, 9)] + history_run_ids = [f"run_{i:02d}" for i in range(8)] + values = [float(i) for i in range(1, 9)] # distinct values so baseline is identifiable + history = pd.DataFrame( + {"result_signal": values, "test_run_id": history_run_ids}, + index=pd.to_datetime(history_timestamps), + ) + # Freshness events for those 8 runs PLUS 3 more that aren't in history (trimmed) + freshness_updates = history_run_ids + [f"run_{i}" for i in range(20, 23)] + + _, _, baseline, prediction = compute_volume_or_metric_threshold( + history, freshness_updates, PredictSensitivity.medium, + ) + + assert baseline == 8.0 + assert prediction is not None + parsed = json.loads(prediction) + assert parsed["freshness_gated"] is True + # baseline_value must be the value at the LAST timestamp present in BOTH history and + # freshness_updates (not freshness_updates[-1] which points past the history window) + assert parsed["baseline_value"] == 8.0 diff --git a/tests/unit/common/conftest.py b/tests/unit/common/conftest.py index 8646de5c..875d147b 100644 --- a/tests/unit/common/conftest.py +++ b/tests/unit/common/conftest.py @@ -134,18 +134,23 @@ def _run_scenario( sensitivity: PredictSensitivity, exclude_weekends: bool = False, tz: str | None = None, + min_lookback: int = 30, ) -> list[ScenarioPoint]: - """Iterate through csv_rows calling compute_freshness_threshold at each step.""" + """Iterate through csv_rows, mirroring the call shape of TestThresholdsPrediction.run(): + a min_lookback guard against the raw history, then compute_freshness_threshold.""" results: list[ScenarioPoint] = [] freshness_last_update: pd.Timestamp | None = None for i, (timestamp, value) in enumerate(csv_rows): history_df = _to_history_df(csv_rows[:i]) - lower, upper, staleness, prediction_json = compute_freshness_threshold( - history_df, sensitivity, min_lookback=30, - exclude_weekends=exclude_weekends, schedule_tz=tz, - ) + if len(history_df) < min_lookback: + lower = upper = staleness = prediction_json = None + else: + lower, upper, staleness, prediction_json = compute_freshness_threshold( + history_df, sensitivity, + exclude_weekends=exclude_weekends, schedule_tz=tz, + ) result_code, result_status = _evaluate_freshness_point( timestamp, value, lower, upper, staleness, prediction_json, diff --git a/tests/unit/common/test_freshness_scenarios.py b/tests/unit/common/test_freshness_scenarios.py index 86111b1a..a20d2a52 100644 --- a/tests/unit/common/test_freshness_scenarios.py +++ b/tests/unit/common/test_freshness_scenarios.py @@ -65,12 +65,12 @@ def results_no_excl(self) -> list[ScenarioPoint]: return _run_scenario(rows, PredictSensitivity.medium, exclude_weekends=False, tz=None) def test_training_exits(self, results_excl: list[ScenarioPoint]) -> None: - """Training should end. First non-training update needs 5 gaps + min_lookback=30 rows.""" + """Training should end once MIN_FRESHNESS_GAPS (5) completed gaps are observed.""" updates = _updates(results_excl) first_non_training = next((i for i, p in enumerate(updates) if p.upper is not None), None) assert first_non_training is not None - # 5 weekday updates = 5 gaps, but min_lookback=30 means ~30 rows needed first - # With 12h obs interval and daily updates, training exits around update 10-14 + # 5 weekday updates yield 5 gaps; with 12h obs interval and daily updates, + # training exits soon after. assert 6 <= first_non_training <= 16 def test_zero_anomalies_excl(self, results_excl: list[ScenarioPoint]) -> None: diff --git a/tests/unit/common/test_freshness_service.py b/tests/unit/common/test_freshness_service.py index f8317413..2797d965 100644 --- a/tests/unit/common/test_freshness_service.py +++ b/tests/unit/common/test_freshness_service.py @@ -15,6 +15,7 @@ detect_active_days, detect_update_window, get_freshness_gap_threshold, + get_freshness_gated_baseline, get_schedule_params, infer_schedule, is_excluded_day, @@ -757,6 +758,40 @@ def test_no_exclusions_for_tentative(self): assert result.window_start is None assert result.window_end is None + +# --------------------------------------------------------------------------- +# get_freshness_gated_baseline Tests +# --------------------------------------------------------------------------- + +class Test_GetFreshnessGatedBaseline: + def test_returns_none_for_none(self): + assert get_freshness_gated_baseline(None) is None + + def test_returns_none_for_empty_string(self): + assert get_freshness_gated_baseline("") is None + + def test_returns_none_when_freshness_gated_absent(self): + assert get_freshness_gated_baseline({"mean": {"123": 100.0}}) is None + + def test_returns_none_when_freshness_gated_false(self): + assert get_freshness_gated_baseline({"freshness_gated": False, "baseline_value": 100.0}) is None + + def test_returns_baseline_when_freshness_gated_true(self): + assert get_freshness_gated_baseline({"freshness_gated": True, "baseline_value": 220.0}) == 220.0 + + def test_parses_from_json_string(self): + pred = json.dumps({"freshness_gated": True, "baseline_value": 5.5}) + assert get_freshness_gated_baseline(pred) == 5.5 + + def test_returns_none_when_baseline_value_missing(self): + assert get_freshness_gated_baseline({"freshness_gated": True}) is None + + def test_baseline_value_coerced_to_float(self): + """JSON may serialize int — must be cast to float for downstream SQL.""" + result = get_freshness_gated_baseline({"freshness_gated": True, "baseline_value": 220}) + assert isinstance(result, float) + assert result == 220.0 + def test_no_window_when_missing(self): pred = {"frequency": "sub_daily", "schedule_stage": "active"} result = get_schedule_params(pred) diff --git a/tests/unit/common/test_time_series_service.py b/tests/unit/common/test_time_series_service.py index 86e2e8b3..c8a5d66a 100644 --- a/tests/unit/common/test_time_series_service.py +++ b/tests/unit/common/test_time_series_service.py @@ -450,23 +450,6 @@ def test_sensitivity_ordering(self): assert upper_high <= upper_med <= upper_low - def test_min_lookback_respected(self): - # 6 updates with sawtooth rows in between — the helper generates many rows - updates = [f"2026-02-{d:02d}T{h:02d}:00" for d, h in [(1, 0), (1, 10), (1, 20), (2, 6), (2, 16), (3, 2)]] - history = _make_freshness_history(updates) - row_count = len(history) - - # With min_lookback at exactly the row count → should produce thresholds - _, upper, _, _ = compute_freshness_threshold(history, PredictSensitivity.medium, min_lookback=row_count) - assert upper is not None - - # With min_lookback above the row count → training mode - lower, upper, staleness, prediction = compute_freshness_threshold(history, PredictSensitivity.medium, min_lookback=row_count + 1) - assert lower is None - assert upper is None - assert staleness is None - assert prediction is None - class Test_AddBusinessMinutes: def test_no_exclusions(self): start = pd.Timestamp("2026-02-09T08:00") # Monday From e5f6ac013ecc69b9e616d947bc9bce1b5cec64e2 Mon Sep 17 00:00:00 2001 From: Ricardo Boni Date: Wed, 13 May 2026 08:48:56 -0400 Subject: [PATCH 11/58] refactor: centralize /api/v1 prefix in api package router The four API v1 routers each declared `prefix="/api/v1"` at construction. Move the prefix to a single aggregator router exposed by `testgen.api`, so the version prefix lives in one place. Each sub-router now declares only its tags, dependencies, and responses. `server/__init__.py` mounts the aggregator once instead of including four routers individually. OpenAPI paths and tag distribution are unchanged. The unit test that mounted `jobs.router` directly on a `FastAPI()` test app now passes `prefix="/api/v1"` at include time to mirror production wiring; otherwise its `/api/v1/...` requests would 404. Co-Authored-By: Claude Opus 4.7 --- testgen/api/__init__.py | 12 ++++++++++++ testgen/api/app.py | 2 +- testgen/api/jobs.py | 2 +- testgen/api/runs.py | 2 +- testgen/api/test_definitions.py | 1 - testgen/server/__init__.py | 10 ++-------- tests/unit/api/test_jobs.py | 2 +- 7 files changed, 18 insertions(+), 13 deletions(-) diff --git a/testgen/api/__init__.py b/testgen/api/__init__.py index e69de29b..5d82c173 100644 --- a/testgen/api/__init__.py +++ b/testgen/api/__init__.py @@ -0,0 +1,12 @@ +from fastapi import APIRouter + +from testgen.api.app import router as _app_router +from testgen.api.jobs import router as _jobs_router +from testgen.api.runs import router as _runs_router +from testgen.api.test_definitions import router as _test_definitions_router + +router = APIRouter(prefix="/api/v1") +router.include_router(_app_router) +router.include_router(_jobs_router) +router.include_router(_runs_router) +router.include_router(_test_definitions_router) diff --git a/testgen/api/app.py b/testgen/api/app.py index 8111916a..2f81b1fb 100644 --- a/testgen/api/app.py +++ b/testgen/api/app.py @@ -5,7 +5,7 @@ from testgen.api.deps import db_session from testgen.common import version_service -router = APIRouter(prefix="/api/v1", tags=["API"], dependencies=[Depends(db_session)]) +router = APIRouter(tags=["API"], dependencies=[Depends(db_session)]) @router.get("/health") diff --git a/testgen/api/jobs.py b/testgen/api/jobs.py index 7131fdf1..9f70944e 100644 --- a/testgen/api/jobs.py +++ b/testgen/api/jobs.py @@ -19,7 +19,7 @@ 404: {"model": ErrorResponse, "description": "Not found"}, } -router = APIRouter(prefix="/api/v1", tags=["Jobs"], dependencies=[Depends(db_session)], responses=_error_responses) +router = APIRouter(tags=["Jobs"], dependencies=[Depends(db_session)], responses=_error_responses) @router.post( diff --git a/testgen/api/runs.py b/testgen/api/runs.py index 64110721..fcbe0445 100644 --- a/testgen/api/runs.py +++ b/testgen/api/runs.py @@ -26,7 +26,7 @@ 404: {"model": ErrorResponse, "description": "Not found"}, } -router = APIRouter(prefix="/api/v1", tags=["runs"], dependencies=[Depends(db_session)], responses=_error_responses) +router = APIRouter(tags=["runs"], dependencies=[Depends(db_session)], responses=_error_responses) @router.get( diff --git a/testgen/api/test_definitions.py b/testgen/api/test_definitions.py index 6b205cc0..6d2d0e41 100644 --- a/testgen/api/test_definitions.py +++ b/testgen/api/test_definitions.py @@ -21,7 +21,6 @@ } router = APIRouter( - prefix="/api/v1", tags=["Test Definitions"], dependencies=[Depends(db_session)], responses=_error_responses, diff --git a/testgen/server/__init__.py b/testgen/server/__init__.py index 7109c420..d0867e9d 100644 --- a/testgen/server/__init__.py +++ b/testgen/server/__init__.py @@ -17,14 +17,11 @@ if settings.IS_DEBUG: os.environ.setdefault("AUTHLIB_INSECURE_TRANSPORT", "1") -from testgen.api.app import router as api_router -from testgen.api.jobs import router as jobs_router +from testgen.api import router as api_v1_router from testgen.api.oauth.metadata import router as metadata_router from testgen.api.oauth.routes import init_routes from testgen.api.oauth.routes import router as oauth_router from testgen.api.oauth.server import create_authorization_server -from testgen.api.runs import router as runs_router -from testgen.api.test_definitions import router as test_definitions_router from testgen.common import version_service from testgen.common.models import with_database_session from testgen.server.middleware import BodySizeLimitMiddleware, SecurityHeadersMiddleware @@ -124,10 +121,7 @@ def favicon(): app.include_router(metadata_router) app.include_router(oauth_router) - app.include_router(api_router) - app.include_router(jobs_router) - app.include_router(runs_router) - app.include_router(test_definitions_router) + app.include_router(api_v1_router) if settings.MCP_ENABLED: app.mount("", mcp_app) diff --git a/tests/unit/api/test_jobs.py b/tests/unit/api/test_jobs.py index 18d260a4..6af04eb4 100644 --- a/tests/unit/api/test_jobs.py +++ b/tests/unit/api/test_jobs.py @@ -205,7 +205,7 @@ def test_list_jobs_empty_project(mock_je_cls): def _client_with_overrides() -> TestClient: """Build a TestClient that bypasses auth and db_session so query validation runs unimpeded.""" app = FastAPI() - app.include_router(router) + app.include_router(router, prefix="/api/v1") app.dependency_overrides[db_session] = lambda: iter([None]) app.dependency_overrides[get_authorized_user] = lambda: MagicMock(id=uuid4()) return app From 61509f320f4166361dc04498fb11605e341ec060 Mon Sep 17 00:00:00 2001 From: Ricardo Boni Date: Wed, 13 May 2026 09:48:49 -0400 Subject: [PATCH 12/58] refactor: extract _check_access helper for API resolvers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The entity resolver factories in `api/deps.py` each repeated the pattern "if entity and user has permission on its project, return; else raise 404". Extract this into `_check_access`, where the no-info-leakage security intent (not-found and unauthorized both surface as the same 404) is documented once. `resolve_project_code` keeps its own shape — it has no entity lookup, so it doesn't fit the pattern. Move the four function-local model imports (TableGroup, TestSuite, JobExecution, sqlalchemy.select) to module level. No circular import. Co-Authored-By: Claude Opus 4.7 --- testgen/api/deps.py | 45 ++++++++++++++++++++++----------------------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/testgen/api/deps.py b/testgen/api/deps.py index 8daac06f..539bcff2 100644 --- a/testgen/api/deps.py +++ b/testgen/api/deps.py @@ -4,10 +4,14 @@ from fastapi import Depends, HTTPException, Security, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from sqlalchemy import select from testgen.common.auth import authorize_token, decode_jwt_token from testgen.common.models import Session, _current_session_wrapper, get_current_session +from testgen.common.models.job_execution import JobExecution from testgen.common.models.project_membership import ProjectMembership +from testgen.common.models.table_group import TableGroup +from testgen.common.models.test_suite import TestSuite from testgen.common.models.user import User from testgen.utils.plugins import PluginHook @@ -73,14 +77,25 @@ def has_project_permission(user: User, project_code: str, permission: str) -> bo # --- Resolver dependency factories --- # Each factory takes a permission string and returns Depends(). The entity ID -# comes from a URL path parameter (FastAPI resolves it natively). -# Entity not found and insufficient permission both raise the same 404 -# with a stable code/message — no variation that could leak the cause. +# comes from a URL path parameter (FastAPI resolves it natively, including +# UUID validation that yields a 422 for malformed inputs). _require_user = Depends(get_authorized_user) _not_found = api_error(404, "not_found", "Not found") +def _check_access(entity, user: User, permission: str): + """Return ``entity`` if the user has ``permission`` on its project, else raise 404. + + Entity-not-found and insufficient-permission both surface as the same 404 + with a stable code/message — no variation that could leak the cause to an + unauthorized caller. + """ + if entity and has_project_permission(user, entity.project_code, permission): + return entity + raise _not_found + + def resolve_project_code(permission: str): """Verify the user has ``permission`` on the project identified by ``project_code`` path param.""" def dependency(project_code: str, user: User = _require_user) -> str: @@ -92,23 +107,15 @@ def dependency(project_code: str, user: User = _require_user) -> str: def resolve_table_group(permission: str): """Resolve a TableGroup by ``table_group_id`` path param and verify project permission.""" - from testgen.common.models.table_group import TableGroup - def dependency(table_group_id: UUID, user: User = _require_user) -> TableGroup: - if (table_group := TableGroup.get(table_group_id)) and has_project_permission(user, table_group.project_code, permission): - return table_group - raise _not_found + return _check_access(TableGroup.get(table_group_id), user, permission) return Depends(dependency) def resolve_test_suite(permission: str): """Resolve a non-monitor TestSuite by ``test_suite_id`` path param and verify project permission.""" - from testgen.common.models.test_suite import TestSuite - def dependency(test_suite_id: UUID, user: User = _require_user) -> TestSuite: - if (test_suite := TestSuite.get_regular(test_suite_id)) and has_project_permission(user, test_suite.project_code, permission): - return test_suite - raise _not_found + return _check_access(TestSuite.get_regular(test_suite_id), user, permission) return Depends(dependency) @@ -116,21 +123,13 @@ def resolve_job(permission: str, *extra_filters): """Resolve a JobExecution by ``job_id`` path param and verify project permission. Internally-submitted jobs (source='system') are never exposed via the API. - Extra ORM clauses are appended to the WHERE clause, e.g. to restrict by job_key. - Mismatches surface as the same 404 — no information leakage. + Extra ORM clauses are appended to the WHERE clause to restrict by job_key. """ - from sqlalchemy import select - - from testgen.common.models.job_execution import JobExecution - def dependency(job_id: UUID, user: User = _require_user) -> JobExecution: query = select(JobExecution).where( JobExecution.id == job_id, JobExecution.source != "system", *extra_filters, ) - job = get_current_session().scalars(query).first() - if job and has_project_permission(user, job.project_code, permission): - return job - raise _not_found + return _check_access(get_current_session().scalars(query).first(), user, permission) return Depends(dependency) From 158331d532b7436e8a6f05641b6250b3f52cd052 Mon Sep 17 00:00:00 2001 From: Ricardo Boni Date: Wed, 13 May 2026 10:07:58 -0400 Subject: [PATCH 13/58] refactor: drop vestigial args column from job_executions and job_schedules MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `args` was never read by the dispatch path — `exec_job` invokes `handler(**je.kwargs)` and ignores args entirely. `JobExecution.submit()` doesn't accept it, and every UI/plugin call site was passing `args=[]`. Remove the column from both tables and the call sites that supplied it. The job_schedules uniqueness constraint included args; the migration resolves the auto-generated PG constraint name dynamically (truncation varies), drops the column, and recreates the constraint as `job_schedules_uniq` over `(project_code, key, kwargs, cron_expr, cron_tz)`. `get_job_arguments` in the schedule dialog returned `tuple[list, dict]`; collapse it to `dict` now that the list half is dead. Co-Authored-By: Claude Opus 4.7 --- testgen/common/models/job_execution.py | 5 ++-- testgen/common/models/scheduler.py | 1 - testgen/scheduler/cli_scheduler.py | 2 -- .../030_initialize_new_schema_structure.sql | 4 +-- .../dbupgrade/0188_incremental_upgrade.sql | 26 +++++++++++++++++++ testgen/ui/views/connections.py | 2 -- testgen/ui/views/dialogs/manage_schedules.py | 6 ++--- testgen/ui/views/monitors_dashboard.py | 1 - testgen/ui/views/profiling_runs.py | 4 +-- testgen/ui/views/table_groups.py | 2 -- testgen/ui/views/test_runs.py | 4 +-- tests/unit/scheduler/test_scheduler_cli.py | 3 +-- tests/unit/scheduler/test_scheduler_poll.py | 4 --- 13 files changed, 36 insertions(+), 28 deletions(-) create mode 100644 testgen/template/dbupgrade/0188_incremental_upgrade.sql diff --git a/testgen/common/models/job_execution.py b/testgen/common/models/job_execution.py index 24a23cbc..54d804c8 100644 --- a/testgen/common/models/job_execution.py +++ b/testgen/common/models/job_execution.py @@ -36,9 +36,8 @@ class JobExecution(Base): id: UUID = Column(postgresql.UUID(as_uuid=True), primary_key=True, default=uuid4) job_key: str = Column(String(100), nullable=False) - # args and kwargs are internal dispatch details passed to the job handler. - # Do not query or filter on them — external code should not depend on their structure. - args: list[Any] = Column(postgresql.JSONB, nullable=False, default=list, server_default=text("'[]'::jsonb")) + # kwargs is the internal dispatch payload passed to the job handler. + # Do not query or filter on it — external code should not depend on its structure. kwargs: dict[str, Any] = Column(postgresql.JSONB, nullable=False, default=dict, server_default=text("'{}'::jsonb")) source: str = Column(String(20), nullable=False) status: str = Column(String(20), nullable=False, default=JobStatus.PENDING, server_default=text("'pending'")) diff --git a/testgen/common/models/scheduler.py b/testgen/common/models/scheduler.py index 825766a8..dda95383 100644 --- a/testgen/common/models/scheduler.py +++ b/testgen/common/models/scheduler.py @@ -26,7 +26,6 @@ class JobSchedule(Base): project_code: str = Column(String) key: str = Column(String, nullable=False) - args: list[Any] = Column(postgresql.JSONB, nullable=False, default=[]) kwargs: dict[str, Any] = Column(postgresql.JSONB, nullable=False, default={}) cron_expr: str = Column(String, nullable=False) cron_tz: str = Column(String, nullable=False) diff --git a/testgen/scheduler/cli_scheduler.py b/testgen/scheduler/cli_scheduler.py index ddc3c9ec..d572b036 100644 --- a/testgen/scheduler/cli_scheduler.py +++ b/testgen/scheduler/cli_scheduler.py @@ -22,7 +22,6 @@ @dataclass class CliJob(Job): key: str - args: Iterable[Any] kwargs: dict[str, Any] project_code: str | None = field(default=None) job_schedule_id: UUID | None = field(default=None) @@ -58,7 +57,6 @@ def get_jobs(self) -> Iterable[CliJob]: cron_tz=job_model.cron_tz, delayed_policy=DelayedPolicy.SKIP, key=job_model.key, - args=job_model.args, kwargs=job_model.kwargs, project_code=job_model.project_code, job_schedule_id=job_model.id, diff --git a/testgen/template/dbsetup/030_initialize_new_schema_structure.sql b/testgen/template/dbsetup/030_initialize_new_schema_structure.sql index 1e7217df..e77aa9c1 100644 --- a/testgen/template/dbsetup/030_initialize_new_schema_structure.sql +++ b/testgen/template/dbsetup/030_initialize_new_schema_structure.sql @@ -1066,12 +1066,11 @@ CREATE TABLE job_schedules ( id UUID NOT NULL PRIMARY KEY, project_code VARCHAR(30) NOT NULL, key VARCHAR(100) NOT NULL, - args JSONB NOT NULL, kwargs JSONB NOT NULL, cron_expr VARCHAR(50) NOT NULL, cron_tz VARCHAR(30) NOT NULL, active BOOLEAN DEFAULT TRUE, - UNIQUE (project_code, key, args, kwargs, cron_expr, cron_tz) + UNIQUE (project_code, key, kwargs, cron_expr, cron_tz) ); CREATE INDEX job_schedules_idx ON job_schedules (project_code, key); @@ -1079,7 +1078,6 @@ CREATE INDEX job_schedules_idx ON job_schedules (project_code, key); CREATE TABLE job_executions ( id UUID NOT NULL DEFAULT gen_random_uuid() PRIMARY KEY, job_key VARCHAR(100) NOT NULL, - args JSONB NOT NULL DEFAULT '[]'::jsonb, kwargs JSONB NOT NULL DEFAULT '{}'::jsonb, source VARCHAR(20) NOT NULL, status VARCHAR(20) NOT NULL DEFAULT 'pending', diff --git a/testgen/template/dbupgrade/0188_incremental_upgrade.sql b/testgen/template/dbupgrade/0188_incremental_upgrade.sql new file mode 100644 index 00000000..90292af4 --- /dev/null +++ b/testgen/template/dbupgrade/0188_incremental_upgrade.sql @@ -0,0 +1,26 @@ +SET SEARCH_PATH TO {SCHEMA_NAME}; + +-- Drop the unused `args` column from job_schedules and job_executions. +-- It's vestigial: exec_job dispatches via handler(**je.kwargs); no path reads args. +-- The job_schedules UNIQUE constraint includes args, so resolve and drop it dynamically +-- (the auto-generated PG constraint name varies with truncation). + +DO $$ +DECLARE c_name TEXT; +BEGIN + SELECT conname INTO c_name + FROM pg_constraint + WHERE conrelid = 'job_schedules'::regclass + AND contype = 'u' + AND conkey @> ARRAY[(SELECT attnum FROM pg_attribute WHERE attrelid = 'job_schedules'::regclass AND attname = 'args')]; + IF c_name IS NOT NULL THEN + EXECUTE format('ALTER TABLE job_schedules DROP CONSTRAINT %I', c_name); + END IF; +END $$; + +ALTER TABLE job_schedules DROP COLUMN args; + +ALTER TABLE job_schedules + ADD CONSTRAINT job_schedules_uniq UNIQUE (project_code, key, kwargs, cron_expr, cron_tz); + +ALTER TABLE job_executions DROP COLUMN args; diff --git a/testgen/ui/views/connections.py b/testgen/ui/views/connections.py index 35fa58cf..83b548e5 100644 --- a/testgen/ui/views/connections.py +++ b/testgen/ui/views/connections.py @@ -426,7 +426,6 @@ def on_close_clicked(_params: dict) -> None: key=RUN_TESTS_JOB_KEY, cron_expr=standard_test_suite_data["schedule"], cron_tz=standard_test_suite_data["timezone"], - args=[], kwargs={"test_suite_id": str(standard_test_suite.id)}, ).save() @@ -458,7 +457,6 @@ def on_close_clicked(_params: dict) -> None: key=RUN_MONITORS_JOB_KEY, cron_expr=monitor_test_suite_data.get("schedule"), cron_tz=monitor_test_suite_data.get("timezone"), - args=[], kwargs={"test_suite_id": str(monitor_test_suite.id)}, ).save() diff --git a/testgen/ui/views/dialogs/manage_schedules.py b/testgen/ui/views/dialogs/manage_schedules.py index 4caf4c59..5aeeff5d 100644 --- a/testgen/ui/views/dialogs/manage_schedules.py +++ b/testgen/ui/views/dialogs/manage_schedules.py @@ -33,7 +33,7 @@ def get_arg_value(self, job): def get_arg_value_options(self) -> list[dict[str, str]]: raise NotImplementedError - def get_job_arguments(self, arg_value: str) -> tuple[list[Any], dict[str, Any]]: + def get_job_arguments(self, arg_value: str) -> dict[str, Any]: raise NotImplementedError def build_data(self) -> dict: @@ -98,15 +98,13 @@ def on_add(self, payload: dict) -> None: is_form_valid = bool(arg_value) and bool(cron_tz) and bool(cron_expr) if is_form_valid: cron_obj = cron_converter.Cron(cron_expr) - args, kwargs = self.get_job_arguments(arg_value) sched_model = JobSchedule( project_code=self.project_code, key=self.job_key, cron_expr=cron_obj.to_string(), cron_tz=cron_tz, active=True, - args=args, - kwargs=kwargs, + kwargs=self.get_job_arguments(arg_value), ) with_database_session(sched_model.save)() st.session_state[RESULT_KEY] = {"success": True, "message": "Schedule added"} diff --git a/testgen/ui/views/monitors_dashboard.py b/testgen/ui/views/monitors_dashboard.py index 3007c173..d6c2afb1 100644 --- a/testgen/ui/views/monitors_dashboard.py +++ b/testgen/ui/views/monitors_dashboard.py @@ -638,7 +638,6 @@ def on_save_settings_clicked(payload: dict) -> None: new_schedule = JobSchedule( project_code=table_group.project_code, key=RUN_MONITORS_JOB_KEY, - args=[], kwargs={"test_suite_id": str(monitor_suite.id)}, **new_schedule_config, ) diff --git a/testgen/ui/views/profiling_runs.py b/testgen/ui/views/profiling_runs.py index e612d6e6..aab0d91c 100644 --- a/testgen/ui/views/profiling_runs.py +++ b/testgen/ui/views/profiling_runs.py @@ -235,8 +235,8 @@ def get_arg_value_options(self) -> list[dict[str, str]]: for table_group in self.table_groups ] - def get_job_arguments(self, arg_value: str) -> tuple[list[typing.Any], dict[str, typing.Any]]: - return [], {"table_group_id": str(arg_value)} + def get_job_arguments(self, arg_value: str) -> dict[str, typing.Any]: + return {"table_group_id": str(arg_value)} class ProfilingRunNotificationSettingsDialog(NotificationSettingsDialogBase): diff --git a/testgen/ui/views/table_groups.py b/testgen/ui/views/table_groups.py index ef93062c..eef3b7d8 100644 --- a/testgen/ui/views/table_groups.py +++ b/testgen/ui/views/table_groups.py @@ -403,7 +403,6 @@ def on_close_clicked(_params: dict) -> None: key=RUN_TESTS_JOB_KEY, cron_expr=standard_test_suite_data["schedule"], cron_tz=standard_test_suite_data["timezone"], - args=[], kwargs={"test_suite_id": str(standard_test_suite.id)}, ).save() @@ -435,7 +434,6 @@ def on_close_clicked(_params: dict) -> None: key=RUN_MONITORS_JOB_KEY, cron_expr=monitor_test_suite_data.get("schedule"), cron_tz=monitor_test_suite_data.get("timezone"), - args=[], kwargs={"test_suite_id": str(monitor_test_suite.id)}, ).save() diff --git a/testgen/ui/views/test_runs.py b/testgen/ui/views/test_runs.py index b53a0d48..c78cf24c 100644 --- a/testgen/ui/views/test_runs.py +++ b/testgen/ui/views/test_runs.py @@ -281,8 +281,8 @@ def get_arg_value_options(self) -> list[dict[str, str]]: for test_suite in self.test_suites ] - def get_job_arguments(self, arg_value: str) -> tuple[list[typing.Any], dict[str, typing.Any]]: - return [], {"test_suite_id": str(arg_value)} + def get_job_arguments(self, arg_value: str) -> dict[str, typing.Any]: + return {"test_suite_id": str(arg_value)} @with_database_session diff --git a/tests/unit/scheduler/test_scheduler_cli.py b/tests/unit/scheduler/test_scheduler_cli.py index ce2acec3..d4008250 100644 --- a/tests/unit/scheduler/test_scheduler_cli.py +++ b/tests/unit/scheduler/test_scheduler_cli.py @@ -54,7 +54,6 @@ def job_data(): "cron_expr": "*/5 9-17 * * *", "cron_tz": "UTC", "key": "test-job", - "args": ["a"], "kwargs": {"b": "c"}, } @@ -76,7 +75,7 @@ def test_get_jobs(scheduler_instance, db_jobs, job_sched): assert len(jobs) == 1 assert isinstance(jobs[0], CliJob) - for attr in ("cron_expr", "cron_tz", "key", "args", "kwargs"): + for attr in ("cron_expr", "cron_tz", "key", "kwargs"): assert getattr(jobs[0], attr) == getattr(job_sched, attr), f"Attribute '{attr}' does not match" diff --git a/tests/unit/scheduler/test_scheduler_poll.py b/tests/unit/scheduler/test_scheduler_poll.py index 20b3bfeb..e1900d01 100644 --- a/tests/unit/scheduler/test_scheduler_poll.py +++ b/tests/unit/scheduler/test_scheduler_poll.py @@ -36,7 +36,6 @@ def job_exec(): return JobExecution( id=uuid4(), job_key="run-tests", - args=[], kwargs={"test_suite_id": "suite-123"}, source="scheduler", status="claimed", @@ -73,7 +72,6 @@ def test_dispatch_unknown_job_key(scheduler_instance, mock_session): job_exec = JobExecution( id=uuid4(), job_key="nonexistent", - args=[], kwargs={}, source="ui", status="claimed", @@ -220,7 +218,6 @@ def test_poll_loop_routes_cancel_requested(scheduler_instance, mock_session): cancel_job = JobExecution( id=uuid4(), job_key="run-tests", - args=[], kwargs={}, source="ui", status=JobStatus.CANCEL_REQUESTED, @@ -257,7 +254,6 @@ def test_start_job_submits_execution(scheduler_instance, mock_session): cron_tz="UTC", delayed_policy=DelayedPolicy.SKIP, key="run-profile", - args=[], kwargs={"table_group_id": "tg-123"}, job_schedule_id=schedule_id, ) From 453203bc9f3c6a6e0711e6a0d49d691921082c38 Mon Sep 17 00:00:00 2001 From: Ricardo Boni Date: Wed, 13 May 2026 10:56:33 -0400 Subject: [PATCH 14/58] refactor: consolidate cross-cutting enums into common.enums MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move six enums that cross subsystem boundaries into ``testgen/common/enums.py`` so the model, API, MCP, scheduler, and UI layers share a single source of truth: - ``JobKey``, ``JobSource`` (were in ``api/schemas.py``) - ``JobStatus`` (was in ``common/models/job_execution.py``) - ``Disposition``, ``IssueLikelihood``, ``PiiRisk`` (were in ``common/models/hygiene_issue.py``) Every call site is updated to import from ``common.enums`` directly — no re-exports. ``api/schemas.py`` and the model files import their own enums back from the new home. While here, fix ``source="user"`` in the project-settings recalculate trigger. ``user`` was never in the ``JobSource`` enum, was used in exactly one place, and is semantically indistinguishable from ``ui`` (a logged-in user triggering a job from a UI page). Coerce to ``ui`` so the audit label matches every other UI-initiated job. No DB backfill — the API filter for surfaced sources will be tightened separately. Co-Authored-By: Claude Opus 4.7 --- testgen/api/jobs.py | 5 +- testgen/api/schemas.py | 18 +------ testgen/commands/exec_job.py | 3 +- testgen/commands/job_registry.py | 3 +- testgen/commands/job_runner.py | 3 +- testgen/commands/run_quick_start.py | 3 +- testgen/common/enums.py | 53 +++++++++++++++++++++ testgen/common/models/hygiene_issue.py | 24 +--------- testgen/common/models/job_execution.py | 12 +---- testgen/common/models/profiling_run.py | 3 +- testgen/common/models/test_run.py | 3 +- testgen/mcp/tools/common.py | 5 +- testgen/mcp/tools/execution.py | 2 +- testgen/mcp/tools/hygiene_issues.py | 3 +- testgen/scheduler/cli_scheduler.py | 3 +- testgen/ui/views/profiling_runs.py | 3 +- testgen/ui/views/project_settings.py | 2 +- testgen/ui/views/test_runs.py | 3 +- tests/unit/commands/test_job_runner.py | 3 +- tests/unit/mcp/test_tools_common.py | 3 +- tests/unit/mcp/test_tools_hygiene_issues.py | 3 +- tests/unit/mcp/test_tools_profiling.py | 2 +- tests/unit/mcp/test_tools_test_runs.py | 2 +- tests/unit/scheduler/test_scheduler_poll.py | 3 +- tests/unit/ui/test_project_settings.py | 4 +- 25 files changed, 94 insertions(+), 77 deletions(-) diff --git a/testgen/api/jobs.py b/testgen/api/jobs.py index 9f70944e..35ab4232 100644 --- a/testgen/api/jobs.py +++ b/testgen/api/jobs.py @@ -10,8 +10,9 @@ resolve_table_group, resolve_test_suite, ) -from testgen.api.schemas import ErrorResponse, JobKey, JobListResponse, JobResponse, JobSource, JobSubmittedResponse -from testgen.common.models.job_execution import JobExecution, JobStatus +from testgen.api.schemas import ErrorResponse, JobListResponse, JobResponse, JobSubmittedResponse +from testgen.common.enums import JobKey, JobSource, JobStatus +from testgen.common.models.job_execution import JobExecution from testgen.common.models.table_group import TableGroup from testgen.common.models.test_suite import TestSuite diff --git a/testgen/api/schemas.py b/testgen/api/schemas.py index 753152d8..2543f227 100644 --- a/testgen/api/schemas.py +++ b/testgen/api/schemas.py @@ -6,27 +6,11 @@ from pydantic import BaseModel, field_validator -from testgen.common.models.job_execution import JobStatus +from testgen.common.enums import JobKey, JobSource, JobStatus # --- Jobs --- -class JobKey(StrEnum): - run_profile = "run-profile" - run_tests = "run-tests" - run_monitors = "run-monitors" - run_test_generation = "run-test-generation" - - -class JobSource(StrEnum): - api = "api" - ui = "ui" - scheduler = "scheduler" - mcp = "mcp" - cli = "cli" - backfill = "backfill" - - class JobSubmittedResponse(BaseModel): """Returned on 202 Accepted after successful job submission.""" diff --git a/testgen/commands/exec_job.py b/testgen/commands/exec_job.py index 4f63a494..e18de23b 100644 --- a/testgen/commands/exec_job.py +++ b/testgen/commands/exec_job.py @@ -11,9 +11,10 @@ from uuid import UUID from testgen.commands.job_registry import JOB_DISPATCH, run_final_callbacks +from testgen.common.enums import JobStatus from testgen.common.job_context import JobContext, job_context from testgen.common.models import database_session -from testgen.common.models.job_execution import JobExecution, JobStatus +from testgen.common.models.job_execution import JobExecution from testgen.utils import get_exception_message LOG = logging.getLogger("testgen") diff --git a/testgen/commands/job_registry.py b/testgen/commands/job_registry.py index 45d5bfe7..b0fd233f 100644 --- a/testgen/commands/job_registry.py +++ b/testgen/commands/job_registry.py @@ -20,8 +20,9 @@ from testgen.commands.run_score_update import run_score_update from testgen.commands.run_test_execution import run_test_execution from testgen.commands.test_generation import run_test_generation +from testgen.common.enums import JobStatus from testgen.common.models import database_session -from testgen.common.models.job_execution import JobExecution, JobStatus +from testgen.common.models.job_execution import JobExecution from testgen.common.models.profiling_run import ProfilingRun from testgen.common.models.test_run import TestRun from testgen.common.notifications.monitor_run import send_monitor_notifications diff --git a/testgen/commands/job_runner.py b/testgen/commands/job_runner.py index 37a96dc5..00ee523a 100644 --- a/testgen/commands/job_runner.py +++ b/testgen/commands/job_runner.py @@ -11,8 +11,9 @@ from sqlalchemy import select from testgen.commands.exec_job import FINAL_STATUSES, POLL_INTERVAL +from testgen.common.enums import JobStatus from testgen.common.models import database_session, get_current_session -from testgen.common.models.job_execution import JobExecution, JobStatus +from testgen.common.models.job_execution import JobExecution from testgen.common.models.profiling_run import ProfilingRun from testgen.common.models.test_run import TestRun diff --git a/testgen/commands/run_quick_start.py b/testgen/commands/run_quick_start.py index e7a9a84d..0c1d63e0 100644 --- a/testgen/commands/run_quick_start.py +++ b/testgen/commands/run_quick_start.py @@ -19,9 +19,10 @@ set_target_db_params, ) from testgen.common.database.flavor.flavor_service import ConnectionParams +from testgen.common.enums import JobStatus from testgen.common.job_context import JobContext, job_context from testgen.common.models import database_session, with_database_session -from testgen.common.models.job_execution import JobExecution, JobStatus +from testgen.common.models.job_execution import JobExecution from testgen.common.models.scores import ScoreDefinition from testgen.common.models.settings import PersistedSetting from testgen.common.models.table_group import TableGroup diff --git a/testgen/common/enums.py b/testgen/common/enums.py index 94d08a37..73804e33 100644 --- a/testgen/common/enums.py +++ b/testgen/common/enums.py @@ -27,3 +27,56 @@ class ImpactDimension(StrEnum): CONFORMANCE = "Conformance" REGULARITY = "Regularity" USABILITY = "Usability" + + +class JobKey(StrEnum): + """``job_key`` column values for ``job_executions`` and ``job_schedules``.""" + run_profile = "run-profile" + run_tests = "run-tests" + run_monitors = "run-monitors" + run_test_generation = "run-test-generation" + + +class JobSource(StrEnum): + """``source`` column values for ``job_executions``. Identifies which surface + submitted the job — API client, UI, scheduler, MCP tool, CLI, or backfill.""" + api = "api" + ui = "ui" + scheduler = "scheduler" + mcp = "mcp" + cli = "cli" + backfill = "backfill" + + +class JobStatus(StrEnum): + """``status`` column values for ``job_executions``. Lifecycle states; see + ``job_execution.py`` for the transition rules.""" + PENDING = "pending" + CLAIMED = "claimed" + RUNNING = "running" + COMPLETED = "completed" + ERROR = "error" + CANCEL_REQUESTED = "cancel_requested" + CANCELED = "canceled" + + +class Disposition(StrEnum): + """Stored disposition values for ``profile_anomaly_results.disposition`` and + ``test_results.disposition``. The user-facing label for ``INACTIVE`` is "Muted".""" + CONFIRMED = "Confirmed" + DISMISSED = "Dismissed" + INACTIVE = "Inactive" + + +class IssueLikelihood(StrEnum): + """Stored ``profile_anomaly_types.issue_likelihood`` values.""" + DEFINITE = "Definite" + LIKELY = "Likely" + POSSIBLE = "Possible" + POTENTIAL_PII = "Potential PII" + + +class PiiRisk(StrEnum): + """Risk level extracted from PII issue ``detail`` strings via ``priority`` hybrid.""" + HIGH = "High" + MODERATE = "Moderate" diff --git a/testgen/common/models/hygiene_issue.py b/testgen/common/models/hygiene_issue.py index 497c8180..c7683479 100644 --- a/testgen/common/models/hygiene_issue.py +++ b/testgen/common/models/hygiene_issue.py @@ -2,7 +2,6 @@ from collections.abc import Iterable from dataclasses import dataclass from datetime import datetime -from enum import StrEnum from typing import Self from uuid import UUID, uuid4 @@ -12,6 +11,7 @@ from sqlalchemy.orm import aliased, relationship from sqlalchemy.sql.functions import func +from testgen.common.enums import Disposition from testgen.common.models import Base, get_current_session from testgen.common.models.entity import Entity from testgen.common.models.job_execution import JobExecution @@ -22,28 +22,6 @@ PII_RISK_RE = re.compile(r"Risk: (MODERATE|HIGH),") -class Disposition(StrEnum): - """Stored disposition values for ``profile_anomaly_results.disposition`` and - ``test_results.disposition``. The user-facing label for ``INACTIVE`` is "Muted".""" - CONFIRMED = "Confirmed" - DISMISSED = "Dismissed" - INACTIVE = "Inactive" - - -class IssueLikelihood(StrEnum): - """Stored ``profile_anomaly_types.issue_likelihood`` values.""" - DEFINITE = "Definite" - LIKELY = "Likely" - POSSIBLE = "Possible" - POTENTIAL_PII = "Potential PII" - - -class PiiRisk(StrEnum): - """Risk level extracted from PII issue ``detail`` strings via ``priority`` hybrid.""" - HIGH = "High" - MODERATE = "Moderate" - - @dataclass class IssueLikelihoodCounts: """Counts of hygiene issues by likelihood category, with dismissed/inactive separated.""" diff --git a/testgen/common/models/job_execution.py b/testgen/common/models/job_execution.py index 54d804c8..3f33e829 100644 --- a/testgen/common/models/job_execution.py +++ b/testgen/common/models/job_execution.py @@ -1,27 +1,17 @@ import logging from datetime import UTC, datetime -from enum import StrEnum from typing import Any, ClassVar, Self from uuid import UUID, uuid4 from sqlalchemy import Column, String, Text, case, func, select, text, update from sqlalchemy.dialects import postgresql +from testgen.common.enums import JobStatus from testgen.common.models import Base, get_current_session LOG = logging.getLogger("testgen") -class JobStatus(StrEnum): - PENDING = "pending" - CLAIMED = "claimed" - RUNNING = "running" - COMPLETED = "completed" - ERROR = "error" - CANCEL_REQUESTED = "cancel_requested" - CANCELED = "canceled" - - _VALID_TRANSITIONS: dict[JobStatus, frozenset[JobStatus]] = { JobStatus.PENDING: frozenset({JobStatus.CLAIMED, JobStatus.CANCEL_REQUESTED}), JobStatus.CLAIMED: frozenset({JobStatus.RUNNING, JobStatus.ERROR, JobStatus.CANCEL_REQUESTED}), diff --git a/testgen/common/models/profiling_run.py b/testgen/common/models/profiling_run.py index cbdf98b1..2dc05f5d 100644 --- a/testgen/common/models/profiling_run.py +++ b/testgen/common/models/profiling_run.py @@ -11,10 +11,11 @@ from sqlalchemy.orm.attributes import flag_modified from sqlalchemy.sql.expression import case +from testgen.common.enums import JobStatus from testgen.common.models import get_current_session from testgen.common.models.connection import Connection from testgen.common.models.entity import ENTITY_HASH_FUNCS, Entity, EntityMinimal -from testgen.common.models.job_execution import JobExecution, JobStatus +from testgen.common.models.job_execution import JobExecution from testgen.common.models.profile_result import ProfileResult from testgen.common.models.project import Project from testgen.common.models.table_group import TableGroup diff --git a/testgen/common/models/test_run.py b/testgen/common/models/test_run.py index c00c6c6b..a001f6ce 100644 --- a/testgen/common/models/test_run.py +++ b/testgen/common/models/test_run.py @@ -9,10 +9,11 @@ from sqlalchemy.orm.attributes import flag_modified from sqlalchemy.sql.expression import case +from testgen.common.enums import JobStatus from testgen.common.models import get_current_session from testgen.common.models.connection import Connection from testgen.common.models.entity import Entity, EntityMinimal -from testgen.common.models.job_execution import JobExecution, JobStatus +from testgen.common.models.job_execution import JobExecution from testgen.common.models.project import Project from testgen.common.models.table_group import TableGroup from testgen.common.models.test_result import TestResult, TestResultStatus diff --git a/testgen/mcp/tools/common.py b/testgen/mcp/tools/common.py index 78794976..70418587 100644 --- a/testgen/mcp/tools/common.py +++ b/testgen/mcp/tools/common.py @@ -5,10 +5,9 @@ from sqlalchemy import select from testgen.common.date_service import parse_since -from testgen.common.enums import ImpactDimension, QualityDimension +from testgen.common.enums import Disposition, ImpactDimension, IssueLikelihood, JobStatus, PiiRisk, QualityDimension from testgen.common.models import get_current_session -from testgen.common.models.hygiene_issue import Disposition, HygieneIssueType, IssueLikelihood, PiiRisk -from testgen.common.models.job_execution import JobStatus +from testgen.common.models.hygiene_issue import HygieneIssueType from testgen.common.models.profiling_run import ProfilingRun from testgen.common.models.scheduler import JobSchedule from testgen.common.models.table_group import TableGroup diff --git a/testgen/mcp/tools/execution.py b/testgen/mcp/tools/execution.py index 987cbb4d..3885bed5 100644 --- a/testgen/mcp/tools/execution.py +++ b/testgen/mcp/tools/execution.py @@ -2,7 +2,7 @@ from sqlalchemy import select -from testgen.api.schemas import JobKey, JobSource +from testgen.common.enums import JobKey, JobSource from testgen.common.models import get_current_session, with_database_session from testgen.common.models.job_execution import JobExecution from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError diff --git a/testgen/mcp/tools/hygiene_issues.py b/testgen/mcp/tools/hygiene_issues.py index 1a834a4b..19042588 100644 --- a/testgen/mcp/tools/hygiene_issues.py +++ b/testgen/mcp/tools/hygiene_issues.py @@ -4,8 +4,9 @@ from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.functions import func +from testgen.common.enums import Disposition, IssueLikelihood, PiiRisk from testgen.common.models import with_database_session -from testgen.common.models.hygiene_issue import Disposition, HygieneIssue, HygieneIssueType, IssueLikelihood, PiiRisk +from testgen.common.models.hygiene_issue import HygieneIssue, HygieneIssueType from testgen.common.models.job_execution import JobExecution from testgen.common.models.profiling_run import ProfilingRun from testgen.common.models.table_group import TableGroup diff --git a/testgen/scheduler/cli_scheduler.py b/testgen/scheduler/cli_scheduler.py index d572b036..a38090ac 100644 --- a/testgen/scheduler/cli_scheduler.py +++ b/testgen/scheduler/cli_scheduler.py @@ -12,8 +12,9 @@ from testgen import settings from testgen.commands.job_registry import JOB_DISPATCH, run_final_callbacks +from testgen.common.enums import JobStatus from testgen.common.models import database_session, with_database_session -from testgen.common.models.job_execution import JobExecution, JobStatus +from testgen.common.models.job_execution import JobExecution from testgen.common.models.scheduler import JobSchedule from testgen.scheduler.base import DelayedPolicy, Job, Scheduler diff --git a/testgen/ui/views/profiling_runs.py b/testgen/ui/views/profiling_runs.py index aab0d91c..b9f036bf 100644 --- a/testgen/ui/views/profiling_runs.py +++ b/testgen/ui/views/profiling_runs.py @@ -13,8 +13,9 @@ RUN_NOTIFICATIONS_DIALOG_OPEN_COUNT_KEY = "pr:run_notifications_dialog_open_count" import testgen.ui.services.form_service as fm +from testgen.common.enums import JobStatus from testgen.common.models import database_session, get_current_session, with_database_session -from testgen.common.models.job_execution import JobExecution, JobStatus +from testgen.common.models.job_execution import JobExecution from testgen.common.models.notification_settings import ( ProfilingRunNotificationSettings, ProfilingRunNotificationTrigger, diff --git a/testgen/ui/views/project_settings.py b/testgen/ui/views/project_settings.py index 8ed1a45e..39f98abd 100644 --- a/testgen/ui/views/project_settings.py +++ b/testgen/ui/views/project_settings.py @@ -81,7 +81,7 @@ def update_project(self, project_code: str, edited_project: dict) -> None: JobExecution.submit( job_key="recalculate-project-scores", kwargs={"project_code": project_code}, - source="user", + source="ui", project_code=project_code, ) st.toast("Scores will be recalculated in the background.") diff --git a/testgen/ui/views/test_runs.py b/testgen/ui/views/test_runs.py index c78cf24c..408ac797 100644 --- a/testgen/ui/views/test_runs.py +++ b/testgen/ui/views/test_runs.py @@ -6,8 +6,9 @@ import streamlit as st import testgen.ui.services.form_service as fm +from testgen.common.enums import JobStatus from testgen.common.models import database_session, get_current_session, with_database_session -from testgen.common.models.job_execution import JobExecution, JobStatus +from testgen.common.models.job_execution import JobExecution from testgen.common.models.notification_settings import ( TestRunNotificationSettings, TestRunNotificationTrigger, diff --git a/tests/unit/commands/test_job_runner.py b/tests/unit/commands/test_job_runner.py index 3ac4ffa5..dadc66d2 100644 --- a/tests/unit/commands/test_job_runner.py +++ b/tests/unit/commands/test_job_runner.py @@ -4,7 +4,8 @@ import pytest from testgen.commands.job_runner import submit_and_wait -from testgen.common.models.job_execution import JobExecution, JobStatus +from testgen.common.enums import JobStatus +from testgen.common.models.job_execution import JobExecution pytestmark = pytest.mark.unit diff --git a/tests/unit/mcp/test_tools_common.py b/tests/unit/mcp/test_tools_common.py index 4a9eb432..77655925 100644 --- a/tests/unit/mcp/test_tools_common.py +++ b/tests/unit/mcp/test_tools_common.py @@ -3,8 +3,7 @@ import pytest -from testgen.common.enums import ImpactDimension, QualityDimension -from testgen.common.models.hygiene_issue import Disposition, IssueLikelihood, PiiRisk +from testgen.common.enums import Disposition, ImpactDimension, IssueLikelihood, PiiRisk, QualityDimension from testgen.common.models.test_result import TestResultStatus from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError from testgen.mcp.tools.common import ( diff --git a/tests/unit/mcp/test_tools_hygiene_issues.py b/tests/unit/mcp/test_tools_hygiene_issues.py index 0a81ce3f..741c9b4b 100644 --- a/tests/unit/mcp/test_tools_hygiene_issues.py +++ b/tests/unit/mcp/test_tools_hygiene_issues.py @@ -5,7 +5,8 @@ import pytest from sqlalchemy.dialects import postgresql -from testgen.common.models.hygiene_issue import Disposition, HygieneIssue, IssueLikelihood +from testgen.common.enums import Disposition, IssueLikelihood +from testgen.common.models.hygiene_issue import HygieneIssue from testgen.common.models.profiling_run import ProfilingRun from testgen.common.pii_masking import PII_REDACTED from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError diff --git a/tests/unit/mcp/test_tools_profiling.py b/tests/unit/mcp/test_tools_profiling.py index 5a4ec01f..c45282ce 100644 --- a/tests/unit/mcp/test_tools_profiling.py +++ b/tests/unit/mcp/test_tools_profiling.py @@ -491,7 +491,7 @@ def test_list_profiling_summaries_inaccessible_tg(mock_tg_cls, db_session_mock): from datetime import UTC -from testgen.common.models.job_execution import JobStatus +from testgen.common.enums import JobStatus _RUN_CREATED = datetime(2026, 4, 1, 10, 0, 0, tzinfo=UTC) _RUN_STARTED = datetime(2026, 4, 1, 10, 0, 5, tzinfo=UTC) diff --git a/tests/unit/mcp/test_tools_test_runs.py b/tests/unit/mcp/test_tools_test_runs.py index 3728ae27..7a71380e 100644 --- a/tests/unit/mcp/test_tools_test_runs.py +++ b/tests/unit/mcp/test_tools_test_runs.py @@ -4,7 +4,7 @@ import pytest -from testgen.common.models.job_execution import JobStatus +from testgen.common.enums import JobStatus from testgen.mcp.exceptions import MCPPermissionDenied, MCPResourceNotAccessible, MCPUserError from testgen.mcp.permissions import ProjectPermissions diff --git a/tests/unit/scheduler/test_scheduler_poll.py b/tests/unit/scheduler/test_scheduler_poll.py index e1900d01..6137c2f2 100644 --- a/tests/unit/scheduler/test_scheduler_poll.py +++ b/tests/unit/scheduler/test_scheduler_poll.py @@ -5,7 +5,8 @@ import pytest from testgen.commands.job_registry import JOB_DISPATCH -from testgen.common.models.job_execution import JobExecution, JobStatus +from testgen.common.enums import JobStatus +from testgen.common.models.job_execution import JobExecution from testgen.scheduler.cli_scheduler import CliScheduler pytestmark = pytest.mark.unit diff --git a/tests/unit/ui/test_project_settings.py b/tests/unit/ui/test_project_settings.py index e38f4488..89c0eeac 100644 --- a/tests/unit/ui/test_project_settings.py +++ b/tests/unit/ui/test_project_settings.py @@ -36,7 +36,7 @@ def test_update_project_submits_recalculate_job_when_weights_toggled_on(mock_ses mock_je.submit.assert_called_once_with( job_key="recalculate-project-scores", kwargs={"project_code": "proj"}, - source="user", + source="ui", project_code="proj", ) @@ -50,7 +50,7 @@ def test_update_project_submits_recalculate_job_when_weights_toggled_off(mock_se mock_je.submit.assert_called_once_with( job_key="recalculate-project-scores", kwargs={"project_code": "proj"}, - source="user", + source="ui", project_code="proj", ) From 207f8e4d02c4da36bcdfa19870142b696ca2b062 Mon Sep 17 00:00:00 2001 From: Ricardo Boni Date: Wed, 13 May 2026 11:30:29 -0400 Subject: [PATCH 15/58] refactor: gate public job exposure by job_key allowlist, not source MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The ``source`` column on ``job_executions`` is a pure audit label — it records which surface submitted a job. The previous gating filters (``source != 'system'`` in ``api/jobs.py``, ``api/deps.py``, and ``mcp/tools/execution.py``) conflated audit label with visibility rule and were under-inclusive: ``run-monitors`` and ``recalculate-project- scores`` were leaking through to public API listings because they weren't submitted with ``source='system'``. Introduce ``PUBLIC_JOB_KEYS`` in ``common/models/job_execution.py``: the frozenset of ``JobKey`` values that external consumers may see. Replace the two API source filters with ``job_key.in_(PUBLIC_JOB_KEYS)``. Delete the MCP filter entirely — each cancel tool pins ``expected_job_key`` to a public kind, so the source filter was redundant there. Tighten the contract so the audit label can never silently take a typo or stale value: - ``JobExecution.submit(source: JobSource)`` instead of ``str`` - ``JobContext.source: JobSource = JobSource.cli`` instead of the bare ``"CLI"`` default (downstream Mixpanel sites already ``.upper()`` so no telemetry change) - Add ``system`` to ``JobSource`` so the enum stops omitting a value the codebase already writes from the score-rollup callback - Migrate every production call site from string-literal ``source=...`` to the matching ``JobSource.`` Behavior change worth noting: the new filter is strictly stronger. ``run-monitors`` and ``recalculate-project-scores`` are now correctly hidden from ``/api/v1/projects/.../jobs`` and per-job lookup. Public kinds (``run_profile``, ``run_tests``, ``run_test_generation``) are unchanged. No DB backfill — existing rows keep their historical source labels, which are still valid for analytics. Co-Authored-By: Claude Opus 4.7 --- testgen/api/deps.py | 10 ++++++---- testgen/api/jobs.py | 4 ++-- testgen/commands/job_registry.py | 4 ++-- testgen/commands/job_runner.py | 4 ++-- testgen/commands/run_quick_start.py | 4 ++-- testgen/common/enums.py | 4 ++-- testgen/common/job_context.py | 4 +++- testgen/common/models/job_execution.py | 13 +++++++++++-- testgen/mcp/tools/execution.py | 4 ++-- testgen/scheduler/cli_scheduler.py | 4 ++-- testgen/ui/views/connections.py | 3 ++- testgen/ui/views/data_catalog.py | 3 ++- testgen/ui/views/dialogs/generate_tests_dialog.py | 3 ++- testgen/ui/views/dialogs/run_profiling_dialog.py | 3 ++- testgen/ui/views/dialogs/run_tests_dialog.py | 3 ++- testgen/ui/views/profiling_runs.py | 4 ++-- testgen/ui/views/project_settings.py | 3 ++- testgen/ui/views/table_groups.py | 5 +++-- testgen/ui/views/test_definitions.py | 3 ++- testgen/ui/views/test_runs.py | 4 ++-- testgen/ui/views/test_suites.py | 3 ++- 21 files changed, 57 insertions(+), 35 deletions(-) diff --git a/testgen/api/deps.py b/testgen/api/deps.py index 539bcff2..1807d68d 100644 --- a/testgen/api/deps.py +++ b/testgen/api/deps.py @@ -8,7 +8,7 @@ from testgen.common.auth import authorize_token, decode_jwt_token from testgen.common.models import Session, _current_session_wrapper, get_current_session -from testgen.common.models.job_execution import JobExecution +from testgen.common.models.job_execution import PUBLIC_JOB_KEYS, JobExecution from testgen.common.models.project_membership import ProjectMembership from testgen.common.models.table_group import TableGroup from testgen.common.models.test_suite import TestSuite @@ -122,13 +122,15 @@ def dependency(test_suite_id: UUID, user: User = _require_user) -> TestSuite: def resolve_job(permission: str, *extra_filters): """Resolve a JobExecution by ``job_id`` path param and verify project permission. - Internally-submitted jobs (source='system') are never exposed via the API. - Extra ORM clauses are appended to the WHERE clause to restrict by job_key. + Only jobs whose ``job_key`` is in ``PUBLIC_JOB_KEYS`` are exposed via the API. + Internal kinds (score rollups, recalculations, monitor runs) are filtered out + by construction. Extra ORM clauses are appended to the WHERE clause to further + restrict by job_key when a caller wants a single kind. """ def dependency(job_id: UUID, user: User = _require_user) -> JobExecution: query = select(JobExecution).where( JobExecution.id == job_id, - JobExecution.source != "system", + JobExecution.job_key.in_(PUBLIC_JOB_KEYS), *extra_filters, ) return _check_access(get_current_session().scalars(query).first(), user, permission) diff --git a/testgen/api/jobs.py b/testgen/api/jobs.py index 35ab4232..3ab291cf 100644 --- a/testgen/api/jobs.py +++ b/testgen/api/jobs.py @@ -12,7 +12,7 @@ ) from testgen.api.schemas import ErrorResponse, JobListResponse, JobResponse, JobSubmittedResponse from testgen.common.enums import JobKey, JobSource, JobStatus -from testgen.common.models.job_execution import JobExecution +from testgen.common.models.job_execution import PUBLIC_JOB_KEYS, JobExecution from testgen.common.models.table_group import TableGroup from testgen.common.models.test_suite import TestSuite @@ -106,7 +106,7 @@ def list_jobs( """List job executions for a project, with optional filters and pagination.""" items, total = JobExecution.list_for_project( project_code, - JobExecution.source != "system", + JobExecution.job_key.in_(PUBLIC_JOB_KEYS), job_key=job_key, status=status, page=page, diff --git a/testgen/commands/job_registry.py b/testgen/commands/job_registry.py index b0fd233f..e2bdde6c 100644 --- a/testgen/commands/job_registry.py +++ b/testgen/commands/job_registry.py @@ -20,7 +20,7 @@ from testgen.commands.run_score_update import run_score_update from testgen.commands.run_test_execution import run_test_execution from testgen.commands.test_generation import run_test_generation -from testgen.common.enums import JobStatus +from testgen.common.enums import JobSource, JobStatus from testgen.common.models import database_session from testgen.common.models.job_execution import JobExecution from testgen.common.models.profiling_run import ProfilingRun @@ -97,7 +97,7 @@ def _enqueue_score_update(job_exec: JobExecution) -> None: "parent_job_id": str(job_exec.id), "parent_job_key": job_exec.job_key, }, - source="system", + source=JobSource.system, project_code=job_exec.project_code, ) diff --git a/testgen/commands/job_runner.py b/testgen/commands/job_runner.py index 00ee523a..f95584e6 100644 --- a/testgen/commands/job_runner.py +++ b/testgen/commands/job_runner.py @@ -11,7 +11,7 @@ from sqlalchemy import select from testgen.commands.exec_job import FINAL_STATUSES, POLL_INTERVAL -from testgen.common.enums import JobStatus +from testgen.common.enums import JobSource, JobStatus from testgen.common.models import database_session, get_current_session from testgen.common.models.job_execution import JobExecution from testgen.common.models.profiling_run import ProfilingRun @@ -35,7 +35,7 @@ def submit_and_wait( job_exec = JobExecution.submit( job_key=job_key, kwargs=kwargs, - source="cli", + source=JobSource.cli, project_code=project_code, ) job_id = job_exec.id diff --git a/testgen/commands/run_quick_start.py b/testgen/commands/run_quick_start.py index 0c1d63e0..64db90d4 100644 --- a/testgen/commands/run_quick_start.py +++ b/testgen/commands/run_quick_start.py @@ -19,7 +19,7 @@ set_target_db_params, ) from testgen.common.database.flavor.flavor_service import ConnectionParams -from testgen.common.enums import JobStatus +from testgen.common.enums import JobSource, JobStatus from testgen.common.job_context import JobContext, job_context from testgen.common.models import database_session, with_database_session from testgen.common.models.job_execution import JobExecution @@ -52,7 +52,7 @@ def run_with_job_execution( effective_date = run_date or datetime.now(UTC) wall_start = datetime.now(UTC) # Match the source a real trigger would use so demo data mirrors production attribution. - source = "scheduler" if job_key == "run-monitors" else "ui" + source = JobSource.scheduler if job_key == "run-monitors" else JobSource.ui with database_session() as session: je = JobExecution( diff --git a/testgen/common/enums.py b/testgen/common/enums.py index 73804e33..15368356 100644 --- a/testgen/common/enums.py +++ b/testgen/common/enums.py @@ -38,14 +38,14 @@ class JobKey(StrEnum): class JobSource(StrEnum): - """``source`` column values for ``job_executions``. Identifies which surface - submitted the job — API client, UI, scheduler, MCP tool, CLI, or backfill.""" + """``source`` column values for ``job_executions``.""" api = "api" ui = "ui" scheduler = "scheduler" mcp = "mcp" cli = "cli" backfill = "backfill" + system = "system" class JobStatus(StrEnum): diff --git a/testgen/common/job_context.py b/testgen/common/job_context.py index e711899e..8d9b2036 100644 --- a/testgen/common/job_context.py +++ b/testgen/common/job_context.py @@ -4,11 +4,13 @@ from dataclasses import dataclass from uuid import UUID +from testgen.common.enums import JobSource + @dataclass(frozen=True) class JobContext: job_id: UUID | None = None - source: str = "CLI" + source: JobSource = JobSource.cli job_context: contextvars.ContextVar[JobContext] = contextvars.ContextVar("job_context", default=JobContext()) diff --git a/testgen/common/models/job_execution.py b/testgen/common/models/job_execution.py index 3f33e829..0f2f2660 100644 --- a/testgen/common/models/job_execution.py +++ b/testgen/common/models/job_execution.py @@ -6,12 +6,21 @@ from sqlalchemy import Column, String, Text, case, func, select, text, update from sqlalchemy.dialects import postgresql -from testgen.common.enums import JobStatus +from testgen.common.enums import JobKey, JobSource, JobStatus from testgen.common.models import Base, get_current_session LOG = logging.getLogger("testgen") +# Job kinds that are externally triggerable. Internal kinds (run-score-update, +# recalculate-project-scores, ...) are absent and filtered out by construction. +PUBLIC_JOB_KEYS: frozenset[JobKey] = frozenset({ + JobKey.run_profile, + JobKey.run_tests, + JobKey.run_test_generation, +}) + + _VALID_TRANSITIONS: dict[JobStatus, frozenset[JobStatus]] = { JobStatus.PENDING: frozenset({JobStatus.CLAIMED, JobStatus.CANCEL_REQUESTED}), JobStatus.CLAIMED: frozenset({JobStatus.RUNNING, JobStatus.ERROR, JobStatus.CANCEL_REQUESTED}), @@ -44,7 +53,7 @@ def submit( cls, job_key: str, kwargs: dict[str, Any], - source: str, + source: JobSource, project_code: str, job_schedule_id: UUID | None = None, ) -> Self: diff --git a/testgen/mcp/tools/execution.py b/testgen/mcp/tools/execution.py index 3885bed5..3ed02e0e 100644 --- a/testgen/mcp/tools/execution.py +++ b/testgen/mcp/tools/execution.py @@ -106,7 +106,8 @@ def cancel_profiling_run(job_execution_id: str) -> str: def _resolve_job_execution(job_execution_id: str, expected_job_key: JobKey, kind: str) -> JobExecution: """Resolve a user-submitted job by ID + expected job_key, collapsing missing-or-inaccessible - into one error path. Filters out source='system' jobs (internal rollups, never user-cancelable). + into one error path. Each MCP tool pins ``expected_job_key`` to a public kind, so the + job_key match alone restricts the lookup to externally-visible jobs. """ job_uuid = parse_uuid(job_execution_id, "job_execution_id") perms = get_project_permissions() @@ -114,7 +115,6 @@ def _resolve_job_execution(job_execution_id: str, expected_job_key: JobKey, kind select(JobExecution).where( JobExecution.id == job_uuid, JobExecution.job_key == expected_job_key, - JobExecution.source != "system", JobExecution.project_code.in_(perms.allowed_codes), ) ).first() diff --git a/testgen/scheduler/cli_scheduler.py b/testgen/scheduler/cli_scheduler.py index a38090ac..d5da99fe 100644 --- a/testgen/scheduler/cli_scheduler.py +++ b/testgen/scheduler/cli_scheduler.py @@ -12,7 +12,7 @@ from testgen import settings from testgen.commands.job_registry import JOB_DISPATCH, run_final_callbacks -from testgen.common.enums import JobStatus +from testgen.common.enums import JobSource, JobStatus from testgen.common.models import database_session, with_database_session from testgen.common.models.job_execution import JobExecution from testgen.common.models.scheduler import JobSchedule @@ -79,7 +79,7 @@ def start_job(self, job: CliJob, triggering_time: datetime) -> None: JobExecution.submit( job_key=job.key, kwargs=job.kwargs, - source="scheduler", + source=JobSource.scheduler, project_code=job.project_code, job_schedule_id=job.job_schedule_id, ) diff --git a/testgen/ui/views/connections.py b/testgen/ui/views/connections.py index 83b548e5..a1452bb6 100644 --- a/testgen/ui/views/connections.py +++ b/testgen/ui/views/connections.py @@ -19,6 +19,7 @@ from testgen import settings from testgen.common.database.database_service import empty_cache, get_flavor_service from testgen.common.database.flavor.flavor_service import resolve_connection_params +from testgen.common.enums import JobSource from testgen.common.models import get_current_session, with_database_session from testgen.common.models.connection import Connection, ConnectionMinimal from testgen.common.models.job_execution import JobExecution @@ -471,7 +472,7 @@ def on_close_clicked(_params: dict) -> None: JobExecution.submit( job_key="run-profile", kwargs={"table_group_id": str(table_group.id)}, - source="ui", + source=JobSource.ui, project_code=table_group.project_code, ) message = f"Profiling run started for table group {table_group.table_groups_name}." diff --git a/testgen/ui/views/data_catalog.py b/testgen/ui/views/data_catalog.py index 5327a081..0d02cd7b 100644 --- a/testgen/ui/views/data_catalog.py +++ b/testgen/ui/views/data_catalog.py @@ -11,6 +11,7 @@ from streamlit.delta_generator import DeltaGenerator from testgen.common.database.database_service import get_flavor_service +from testgen.common.enums import JobSource from testgen.common.models import database_session, with_database_session from testgen.common.models.connection import Connection from testgen.common.models.job_execution import JobExecution @@ -143,7 +144,7 @@ def on_run_profiling_confirmed(table_group: dict) -> None: JobExecution.submit( job_key="run-profile", kwargs={"table_group_id": str(table_group["id"])}, - source="ui", + source=JobSource.ui, project_code=project_code, ) except Exception as error: diff --git a/testgen/ui/views/dialogs/generate_tests_dialog.py b/testgen/ui/views/dialogs/generate_tests_dialog.py index e894dcb8..cf4cc715 100644 --- a/testgen/ui/views/dialogs/generate_tests_dialog.py +++ b/testgen/ui/views/dialogs/generate_tests_dialog.py @@ -1,5 +1,6 @@ import streamlit as st +from testgen.common.enums import JobSource from testgen.common.models import database_session, with_database_session from testgen.common.models.job_execution import JobExecution from testgen.common.models.test_suite import TestSuiteMinimal @@ -42,7 +43,7 @@ def on_generate_tests_confirmed(data: dict) -> None: JobExecution.submit( job_key="run-test-generation", kwargs={"test_suite_id": str(test_suite_id), "generation_set": selected_set}, - source="ui", + source=JobSource.ui, project_code=test_suite.project_code, ) st.session_state[RESULT_KEY] = {"success": True, "message": f"Test generation started for test suite '{test_suite_name}'."} diff --git a/testgen/ui/views/dialogs/run_profiling_dialog.py b/testgen/ui/views/dialogs/run_profiling_dialog.py index 9dfbb3ff..80693b94 100644 --- a/testgen/ui/views/dialogs/run_profiling_dialog.py +++ b/testgen/ui/views/dialogs/run_profiling_dialog.py @@ -2,6 +2,7 @@ import streamlit as st +from testgen.common.enums import JobSource from testgen.common.models import database_session from testgen.common.models.job_execution import JobExecution from testgen.ui.components import widgets as testgen @@ -32,7 +33,7 @@ def on_run_profiling_confirmed(table_group: dict) -> None: JobExecution.submit( job_key="run-profile", kwargs={"table_group_id": str(table_group["id"])}, - source="ui", + source=JobSource.ui, project_code=project_code, ) except Exception as error: diff --git a/testgen/ui/views/dialogs/run_tests_dialog.py b/testgen/ui/views/dialogs/run_tests_dialog.py index 29b224e9..33797485 100644 --- a/testgen/ui/views/dialogs/run_tests_dialog.py +++ b/testgen/ui/views/dialogs/run_tests_dialog.py @@ -1,5 +1,6 @@ import streamlit as st +from testgen.common.enums import JobSource from testgen.common.models import database_session from testgen.common.models.job_execution import JobExecution from testgen.common.models.test_suite import TestSuite @@ -34,7 +35,7 @@ def on_run_tests_confirmed(data: dict) -> None: JobExecution.submit( job_key="run-tests", kwargs={"test_suite_id": str(selected_id)}, - source="ui", + source=JobSource.ui, project_code=project_code, ) except Exception as e: diff --git a/testgen/ui/views/profiling_runs.py b/testgen/ui/views/profiling_runs.py index b9f036bf..af705c6f 100644 --- a/testgen/ui/views/profiling_runs.py +++ b/testgen/ui/views/profiling_runs.py @@ -13,7 +13,7 @@ RUN_NOTIFICATIONS_DIALOG_OPEN_COUNT_KEY = "pr:run_notifications_dialog_open_count" import testgen.ui.services.form_service as fm -from testgen.common.enums import JobStatus +from testgen.common.enums import JobSource, JobStatus from testgen.common.models import database_session, get_current_session, with_database_session from testgen.common.models.job_execution import JobExecution from testgen.common.models.notification_settings import ( @@ -115,7 +115,7 @@ def on_run_profiling_confirmed(table_group: dict) -> None: JobExecution.submit( job_key="run-profile", kwargs={"table_group_id": str(table_group["id"])}, - source="ui", + source=JobSource.ui, project_code=project_code, ) except Exception as error: diff --git a/testgen/ui/views/project_settings.py b/testgen/ui/views/project_settings.py index 39f98abd..db0b5011 100644 --- a/testgen/ui/views/project_settings.py +++ b/testgen/ui/views/project_settings.py @@ -5,6 +5,7 @@ import streamlit as st from testgen.commands.run_observability_exporter import test_observability_exporter +from testgen.common.enums import JobSource from testgen.common.models import with_database_session from testgen.common.models.job_execution import JobExecution from testgen.common.models.project import Project @@ -81,7 +82,7 @@ def update_project(self, project_code: str, edited_project: dict) -> None: JobExecution.submit( job_key="recalculate-project-scores", kwargs={"project_code": project_code}, - source="ui", + source=JobSource.ui, project_code=project_code, ) st.toast("Scores will be recalculated in the background.") diff --git a/testgen/ui/views/table_groups.py b/testgen/ui/views/table_groups.py index eef3b7d8..2b909117 100644 --- a/testgen/ui/views/table_groups.py +++ b/testgen/ui/views/table_groups.py @@ -7,6 +7,7 @@ from sqlalchemy.exc import IntegrityError from testgen.commands.test_generation import run_monitor_generation +from testgen.common.enums import JobSource from testgen.common.models import get_current_session, with_database_session from testgen.common.models.connection import Connection from testgen.common.models.job_execution import JobExecution @@ -136,7 +137,7 @@ def on_run_profiling_confirmed(table_group: dict) -> None: JobExecution.submit( job_key="run-profile", kwargs={"table_group_id": str(table_group["id"])}, - source="ui", + source=JobSource.ui, project_code=project_code, ) except Exception as error: @@ -448,7 +449,7 @@ def on_close_clicked(_params: dict) -> None: JobExecution.submit( job_key="run-profile", kwargs={"table_group_id": str(table_group.id)}, - source="ui", + source=JobSource.ui, project_code=table_group.project_code, ) message = f"Profiling run started for table group {table_group.table_groups_name}." diff --git a/testgen/ui/views/test_definitions.py b/testgen/ui/views/test_definitions.py index 3fe28055..a73f30c9 100644 --- a/testgen/ui/views/test_definitions.py +++ b/testgen/ui/views/test_definitions.py @@ -10,6 +10,7 @@ from testgen.common import date_service from testgen.common.custom_test_validation import validate_custom_query from testgen.common.database.database_service import get_flavor_service +from testgen.common.enums import JobSource from testgen.common.models import with_database_session from testgen.common.models.connection import Connection from testgen.common.models.job_execution import JobExecution @@ -447,7 +448,7 @@ def on_run_tests_confirmed(data: dict) -> None: JobExecution.submit( job_key="run-tests", kwargs={"test_suite_id": str(selected_id)}, - source="ui", + source=JobSource.ui, project_code=project_code, ) except Exception as error: diff --git a/testgen/ui/views/test_runs.py b/testgen/ui/views/test_runs.py index 408ac797..2d490069 100644 --- a/testgen/ui/views/test_runs.py +++ b/testgen/ui/views/test_runs.py @@ -6,7 +6,7 @@ import streamlit as st import testgen.ui.services.form_service as fm -from testgen.common.enums import JobStatus +from testgen.common.enums import JobSource, JobStatus from testgen.common.models import database_session, get_current_session, with_database_session from testgen.common.models.job_execution import JobExecution from testgen.common.models.notification_settings import ( @@ -109,7 +109,7 @@ def on_run_tests_confirmed(data: dict) -> None: JobExecution.submit( job_key="run-tests", kwargs={"test_suite_id": str(selected_id)}, - source="ui", + source=JobSource.ui, project_code=project_code, ) except Exception as error: diff --git a/testgen/ui/views/test_suites.py b/testgen/ui/views/test_suites.py index 605774f6..c32c1b88 100644 --- a/testgen/ui/views/test_suites.py +++ b/testgen/ui/views/test_suites.py @@ -4,6 +4,7 @@ from testgen.commands.run_observability_exporter import export_test_results from testgen.commands.test_generation import run_test_generation +from testgen.common.enums import JobSource from testgen.common.models import database_session, with_database_session from testgen.common.models.job_execution import JobExecution from testgen.common.models.notification_settings import TestRunNotificationSettings @@ -173,7 +174,7 @@ def on_run_tests_confirmed(data: dict) -> None: JobExecution.submit( job_key="run-tests", kwargs={"test_suite_id": str(selected_id)}, - source="ui", + source=JobSource.ui, project_code=project_code, ) except Exception as error: From 470fc1eabf13b230db886a95f91efe06042228b0 Mon Sep 17 00:00:00 2001 From: Luis Date: Wed, 13 May 2026 13:23:30 -0400 Subject: [PATCH 16/58] fix(scorecards): filter categories by CDE --- testgen/common/models/scores.py | 5 +- .../common/models/test_score_definition.py | 113 ++++++++++++++++++ 2 files changed, 117 insertions(+), 1 deletion(-) create mode 100644 tests/unit/common/models/test_score_definition.py diff --git a/testgen/common/models/scores.py b/testgen/common/models/scores.py index b5fc9545..cc693244 100644 --- a/testgen/common/models/scores.py +++ b/testgen/common/models/scores.py @@ -259,6 +259,9 @@ def as_score_card(self) -> ScoreCard: ).replace("{filters}", filters)) ).mappings().first() or {} + cde_only_categories = self.cde_score and not self.total_score + category_filters = " AND ".join(self._get_raw_query_filters(cde_only=cde_only_categories)) + categories_scores = [] if (category := self.category): categories_scores = [ @@ -267,7 +270,7 @@ def as_score_card(self) -> ScoreCard: text(read_template_sql_file( categories_query_template_file, sub_directory="score_cards", - ).replace("{category}", category.value).replace("{filters}", filters)) + ).replace("{category}", category.value).replace("{filters}", category_filters)) ).mappings().all() ] diff --git a/tests/unit/common/models/test_score_definition.py b/tests/unit/common/models/test_score_definition.py new file mode 100644 index 00000000..1080e37e --- /dev/null +++ b/tests/unit/common/models/test_score_definition.py @@ -0,0 +1,113 @@ +"""Tests for ScoreDefinition.as_score_card() filter behavior across toggle combinations. + +Covers TG-1078: in CDE-only mode (total_score OFF, cde_score ON) the per-category +scores must be computed over CDE columns only. In all other modes the per-category +scores must be computed over the full column universe (no CDE filter). +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from testgen.common.models.scores import ( + ScoreCategory, + ScoreDefinition, + ScoreDefinitionCriteria, + ScoreDefinitionFilter, +) + +pytestmark = pytest.mark.unit + + +CDE_FILTER_FRAGMENT = "critical_data_element = true" + + +def _make_definition( + *, + total_score: bool, + cde_score: bool, + category: ScoreCategory = ScoreCategory.dq_dimension, +) -> ScoreDefinition: + definition = ScoreDefinition( + project_code="demo", + name="Test card", + total_score=total_score, + cde_score=cde_score, + category=category, + ) + definition.criteria = ScoreDefinitionCriteria( + operand="AND", + group_by_field=True, + filters=[ScoreDefinitionFilter(field="table_groups_name", value="my_group")], + ) + return definition + + +def _capture_executed_sql(definition: ScoreDefinition) -> list[str]: + """Run as_score_card() against a mocked session and return the SQL of each execute call.""" + session = MagicMock() + mappings_result = MagicMock() + mappings_result.first.return_value = {} + mappings_result.all.return_value = [] + session.execute.return_value.mappings.return_value = mappings_result + + with patch("testgen.common.models.scores.get_current_session", return_value=session): + definition.as_score_card() + + return [str(call.args[0]) for call in session.execute.call_args_list] + + +@pytest.mark.parametrize( + "category", + [ScoreCategory.dq_dimension, ScoreCategory.impact_dimension, ScoreCategory.business_domain], +) +def test_categories_query_omits_cde_filter_in_total_only_mode(category): + definition = _make_definition(total_score=True, cde_score=False, category=category) + sql_calls = _capture_executed_sql(definition) + + assert len(sql_calls) == 2, "expected one overall and one categories query" + overall_sql, categories_sql = sql_calls + assert CDE_FILTER_FRAGMENT not in categories_sql + assert CDE_FILTER_FRAGMENT not in overall_sql + + +@pytest.mark.parametrize( + "category", + [ScoreCategory.dq_dimension, ScoreCategory.impact_dimension, ScoreCategory.business_domain], +) +def test_categories_query_omits_cde_filter_in_total_and_cde_mode(category): + definition = _make_definition(total_score=True, cde_score=True, category=category) + sql_calls = _capture_executed_sql(definition) + + assert len(sql_calls) == 2 + overall_sql, categories_sql = sql_calls + assert CDE_FILTER_FRAGMENT not in categories_sql + assert CDE_FILTER_FRAGMENT not in overall_sql + + +@pytest.mark.parametrize( + "category", + [ScoreCategory.dq_dimension, ScoreCategory.impact_dimension, ScoreCategory.business_domain], +) +def test_categories_query_includes_cde_filter_in_cde_only_mode(category): + definition = _make_definition(total_score=False, cde_score=True, category=category) + sql_calls = _capture_executed_sql(definition) + + assert len(sql_calls) == 2 + overall_sql, categories_sql = sql_calls + assert CDE_FILTER_FRAGMENT in categories_sql, ( + "Categories query must filter by CDE columns when the card is in CDE-only mode" + ) + # Overall query must stay un-filtered by CDE — it selects score and cde_score as + # separate columns, so adding the filter would zero out the non-CDE total. + assert CDE_FILTER_FRAGMENT not in overall_sql + + +def test_categories_query_uses_column_template_for_column_category(): + definition = _make_definition(total_score=False, cde_score=True, category=ScoreCategory.business_domain) + sql_calls = _capture_executed_sql(definition) + + categories_sql = sql_calls[1] + # Column-grouped template aggregates by a placeholder substituted into the SELECT. + assert "business_domain" in categories_sql + assert CDE_FILTER_FRAGMENT in categories_sql From 39d4798a127a88644a5b1c979f07824e5b27d23b Mon Sep 17 00:00:00 2001 From: Ricardo Boni Date: Wed, 13 May 2026 14:51:42 -0400 Subject: [PATCH 17/58] fix: drop args column from quick-start seed insert MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The seed INSERT for ``job_schedules`` still listed the ``args`` column, which migration 0188 dropped. ``testgen quick-start`` (and the Quickstart Test / UI Flows CI jobs that exercise it) failed with ``column "args" does not exist`` on a fresh install. Drop ``args`` from both the column list and the SELECT projection. Fix a pre-existing cosmetic typo while in here: ``TRUE AS TRUE`` → ``TRUE AS active`` (positional INSERT, so the previous alias was ignored at runtime). Co-Authored-By: Claude Opus 4.7 --- testgen/template/quick_start/initial_data_seeding.sql | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/testgen/template/quick_start/initial_data_seeding.sql b/testgen/template/quick_start/initial_data_seeding.sql index fb1283ca..0b5720dd 100644 --- a/testgen/template/quick_start/initial_data_seeding.sql +++ b/testgen/template/quick_start/initial_data_seeding.sql @@ -58,15 +58,14 @@ SELECT '823a1fef-9b6d-48d5-9d0f-2db9812cc318'::UUID AS id, 30 AS predict_min_lookback; INSERT INTO job_schedules - (id, project_code, key, args, kwargs, cron_expr, cron_tz, active) + (id, project_code, key, kwargs, cron_expr, cron_tz, active) SELECT 'eac9d722-d06a-4b1f-b8c4-bb2854bd4cfd'::UUID AS id, '{PROJECT_CODE}' AS project_code, 'run-monitors' AS key, - '[]'::JSONB AS args, '{"test_suite_id": "823a1fef-9b6d-48d5-9d0f-2db9812cc318"}'::JSONB AS kwargs, '0 */12 * * *' AS cron_expr, 'UTC' AS cron_tz, - TRUE AS TRUE; + TRUE AS active; UPDATE table_groups SET monitor_test_suite_id = '823a1fef-9b6d-48d5-9d0f-2db9812cc318'::UUID From f3a1582bed1cc2121998f40375913eacb93d5ecd Mon Sep 17 00:00:00 2001 From: Aarthy Adityan Date: Thu, 14 May 2026 22:47:18 -0400 Subject: [PATCH 18/58] fix(standalone): resolve embedded host/port at connection-build time MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Windows pgserver picks a fresh ephemeral TCP port on every startup, so the demo-DB connection row written during quick-start became stale as soon as run-app started a new pgserver session — "Test Connection" and any target-DB query failed with "connection refused". Store a sentinel in project_host instead of the live host/port; resolve_connection_params rewrites it to the live values when standalone mode is active. Single chokepoint covers UI test-connection, profiling, test execution, and quick-start increments. Co-Authored-By: Claude Opus 4.7 (1M context) --- testgen/commands/run_launch_db_config.py | 9 +++++---- testgen/commands/run_quick_start.py | 9 +++++---- testgen/common/database/flavor/flavor_service.py | 11 +++++++++-- testgen/common/standalone_postgres.py | 7 +++++++ 4 files changed, 26 insertions(+), 10 deletions(-) diff --git a/testgen/commands/run_launch_db_config.py b/testgen/commands/run_launch_db_config.py index 11b2257f..f7b69bcb 100644 --- a/testgen/commands/run_launch_db_config.py +++ b/testgen/commands/run_launch_db_config.py @@ -9,7 +9,7 @@ from testgen.common.models import with_database_session from testgen.common.read_file import get_template_files from testgen.common.read_yaml_metadata_records import import_metadata_records_from_yaml -from testgen.common.standalone_postgres import get_target_host_port, is_standalone_mode +from testgen.common.standalone_postgres import EMBEDDED_HOST_SENTINEL, is_standalone_mode LOG = logging.getLogger("testgen") @@ -28,9 +28,10 @@ def _get_params_mapping() -> dict: project_user = settings.PROJECT_DATABASE_USER project_password = settings.PROJECT_DATABASE_PASSWORD if is_standalone_mode(): - project_host, server_port = get_target_host_port() - if server_port: - project_port = server_port + # Live host/port are resolved at connection-build time via the sentinel + # so the row survives pgserver picking a fresh ephemeral port on Windows. + project_host = EMBEDDED_HOST_SENTINEL + project_port = "" project_user = "postgres" project_password = "" diff --git a/testgen/commands/run_quick_start.py b/testgen/commands/run_quick_start.py index 64db90d4..ae4940ff 100644 --- a/testgen/commands/run_quick_start.py +++ b/testgen/commands/run_quick_start.py @@ -28,7 +28,7 @@ from testgen.common.models.table_group import TableGroup from testgen.common.notifications.base import smtp_configured from testgen.common.read_file import read_template_sql_file -from testgen.common.standalone_postgres import get_target_host_port, is_standalone_mode +from testgen.common.standalone_postgres import EMBEDDED_HOST_SENTINEL, is_standalone_mode LOG = logging.getLogger("testgen") random.seed(42) @@ -149,9 +149,10 @@ def _get_settings_params_mapping() -> dict: admin_user = settings.DATABASE_ADMIN_USER admin_password = settings.DATABASE_ADMIN_PASSWORD if is_standalone_mode(): - host, server_port = get_target_host_port() - if server_port: - port = server_port + # Live host/port are resolved at connection-build time via the sentinel + # so the row survives pgserver picking a fresh ephemeral port on Windows. + host = EMBEDDED_HOST_SENTINEL + port = "" admin_user = "postgres" admin_password = "" diff --git a/testgen/common/database/flavor/flavor_service.py b/testgen/common/database/flavor/flavor_service.py index 09406ea4..0f4f576b 100644 --- a/testgen/common/database/flavor/flavor_service.py +++ b/testgen/common/database/flavor/flavor_service.py @@ -61,13 +61,20 @@ def _decrypt_if_needed(value: Any) -> str | None: def resolve_connection_params(connection_params: ConnectionParams) -> ResolvedConnectionParams: sql_flavor = connection_params.get("sql_flavor") or "" + host = connection_params.get("project_host") or "" + port = connection_params.get("project_port") or "" + # Lazy import to keep the flavor layer free of standalone concerns at module load. + from testgen.common.standalone_postgres import EMBEDDED_HOST_SENTINEL, get_target_host_port, is_standalone_mode + if host == EMBEDDED_HOST_SENTINEL and is_standalone_mode(): + host, live_port = get_target_host_port() + port = live_port or "" return ResolvedConnectionParams( url=connection_params.get("url") or "", connect_by_url=connection_params.get("connect_by_url", False), username=connection_params.get("project_user") or "", password=_decrypt_if_needed(connection_params.get("project_pw_encrypted")), - host=connection_params.get("project_host") or "", - port=connection_params.get("project_port") or "", + host=host, + port=port, dbname=connection_params.get("project_db") or "", dbschema=connection_params.get("table_group_schema"), sql_flavor=sql_flavor, diff --git a/testgen/common/standalone_postgres.py b/testgen/common/standalone_postgres.py index d272eb94..be27e774 100644 --- a/testgen/common/standalone_postgres.py +++ b/testgen/common/standalone_postgres.py @@ -23,6 +23,13 @@ HOME_DIR_ENV_VAR = "TG_TESTGEN_HOME" STANDALONE_URI_ENV_VAR = "_TG_STANDALONE_URI" +# Stored as ``project_host`` in the demo-DB connection row so that the actual +# host/port — which can change across sessions on Windows (pgserver picks a +# fresh ephemeral TCP port each startup) — are resolved at connection-build +# time by ``resolve_connection_params``. Angle brackets are illegal in real +# hostnames, so this can't collide with a user-defined connection. +EMBEDDED_HOST_SENTINEL = "" + def get_home_dir() -> Path: env_dir = os.getenv(HOME_DIR_ENV_VAR) From 1a6150d28bed7a58bd5fdb8a2ac00f6a1f3e1f76 Mon Sep 17 00:00:00 2001 From: Aarthy Adityan Date: Fri, 15 May 2026 00:12:10 -0400 Subject: [PATCH 19/58] fix(standalone): revert Windows signal forwarding to TerminateProcess MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CTRL_BREAK_EVENT delivery is unreliable on Windows — the scheduler's threading.Event.wait doesn't wake on SIGBREAK from a different thread context, so the parent's children-watcher loop never empties and Ctrl+C hangs the whole tree. Reported by Chip; reproduced on the AWS WorkSpace. The orphan-postgres bug that motivated the CTRL_BREAK_EVENT switch is already addressed by the other half of fde7321 — children call ensure_standalone_setup() instead of get_server(), so they no longer register PIDs in pgserver's on-disk handle list. Force-killing them is safe; only the parent owns the pgserver handle and exits via the normal atexit path. Co-Authored-By: Claude Opus 4.7 (1M context) --- testgen/__main__.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/testgen/__main__.py b/testgen/__main__.py index 3b8d7f12..a1ea67a1 100644 --- a/testgen/__main__.py +++ b/testgen/__main__.py @@ -146,12 +146,16 @@ def _install_shutdown_handler(handler) -> None: def _forward_signal_to_child(child: subprocess.Popen, signum: int) -> None: - # POSIX: forward the signal verbatim. Windows: send CTRL_BREAK_EVENT so the - # child's atexit hooks run (it deregisters from pgserver's PID list, lets - # streamlit/uvicorn shut down their event loops, etc.). The child must have - # been spawned with CREATE_NEW_PROCESS_GROUP — see _subprocess_spawn_kwargs. + # POSIX: forward the signal verbatim. Windows: TerminateProcess — graceful + # CTRL_BREAK_EVENT delivery is unreliable here because some children + # (notably the scheduler thread blocked in threading.Event.wait, and + # Streamlit/tornado) don't wake on SIGBREAK, leaving the parent's + # children-watcher loop hung forever. Force-kill is safe: children no + # longer call pgserver.get_server(), so there's no on-disk PID registry + # state to clean up — only the parent owns the pgserver handle and + # exits via the normal atexit path. if sys.platform == "win32": - child.send_signal(signal.CTRL_BREAK_EVENT) + child.terminate() else: child.send_signal(signum) From 8d9b60006389eaf58873ca207a7bdccf0d992704 Mon Sep 17 00:00:00 2001 From: Astor Date: Tue, 28 Apr 2026 09:33:18 -0300 Subject: [PATCH 20/58] feat: add server-side pagination for test definitions (TG-1041) Implements DB-level pagination via _paginate() override in TestDefinition model, with @st.cache_data caching in the UI layer. Replaces full-table loads with paginated queries to handle large test definition sets. Co-Authored-By: Claude Sonnet 4.6 --- testgen/common/models/entity.py | 34 ++++ testgen/common/models/test_definition.py | 39 ++++- testgen/ui/views/test_definitions.py | 47 +++--- .../models/test_test_definition_pagination.py | 145 ++++++++++++++++++ 4 files changed, 242 insertions(+), 23 deletions(-) create mode 100644 tests/unit/common/models/test_test_definition_pagination.py diff --git a/testgen/common/models/entity.py b/testgen/common/models/entity.py index 8f055bda..98671965 100644 --- a/testgen/common/models/entity.py +++ b/testgen/common/models/entity.py @@ -95,6 +95,34 @@ def select_where(cls, *clauses, order_by: tuple[str | InstrumentedAttribute] | N query = select(cls).where(*clauses).order_by(*order_by) return get_current_session().scalars(query).all() + @classmethod + @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) + def _paginate( + cls, + *clauses, + order_by: tuple[str | InstrumentedAttribute] | None = None, + page_index: int = 0, + page_size: int = 500, + ) -> tuple[list[Self], int]: + """Fetch one page of rows plus the total matching count via a window function. + + Uses ``COUNT(*) OVER()`` so only one round-trip to the database is needed. + Returns ``(items, total_count)``. ``page_index`` is 0-based. + """ + order_by = order_by or cls._default_order_by + total_col = func.count().over().label("total_count") + query = ( + select(cls, total_col) + .where(*clauses) + .order_by(*order_by) + .offset(page_index * page_size) + .limit(page_size) + ) + rows = get_current_session().execute(query).all() + items = [row[0] for row in rows] + total = rows[0][1] if rows else 0 + return items, total + @classmethod def select_minimal_where(cls, *clauses, order_by: tuple[str | InstrumentedAttribute]) -> Iterable[Any]: raise NotImplementedError @@ -164,6 +192,12 @@ def is_in_use(cls, ids: list[str]) -> bool: def cascade_delete(cls, ids: list[str]) -> None: raise NotImplementedError + @classmethod + def clear_cache(cls) -> None: + cls.get.clear() + cls.select_where.clear() + cls._paginate.clear() + @classmethod def columns(cls) -> list[str]: return list(cls.__annotations__.keys()) diff --git a/testgen/common/models/test_definition.py b/testgen/common/models/test_definition.py index f10a6b31..9a73717a 100644 --- a/testgen/common/models/test_definition.py +++ b/testgen/common/models/test_definition.py @@ -419,8 +419,9 @@ def list_for_suite( getattr(cls, col, None) or getattr(TestType, col) if isinstance(col, str) else col for col in cls._summary_columns ] + total_col = func.count().over().label("total_count") query = ( - select(*select_columns) + select(*select_columns, total_col) .join(TestType, cls.test_type == TestType.test_type) .join(TestSuite, cls.test_suite_id == TestSuite.id) .where(cls.test_suite_id == test_suite_id, TestSuite.is_monitor.isnot(True)) @@ -433,8 +434,40 @@ def list_for_suite( query = query.where(cls.test_type == test_type) if test_active is not None: query = query.where(cls.test_active == test_active) - query = query.order_by(*cls._default_order_by) - return cls._paginate(query, page=page, limit=limit, data_class=TestDefinitionSummary) + query = query.order_by(*cls._default_order_by).offset((page - 1) * limit).limit(limit) + rows = get_current_session().execute(query).mappings().all() + items = [TestDefinitionSummary(**{k: v for k, v in row.items() if k != "total_count"}) for row in rows] + total = rows[0]["total_count"] if rows else 0 + return items, total + + @classmethod + @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) + def _paginate( + cls, + *clauses, + order_by: tuple[str | InstrumentedAttribute] | None = None, + page_index: int = 0, + page_size: int = 500, + ) -> tuple[list["TestDefinitionSummary"], int]: + select_columns = [ + getattr(cls, col, None) or getattr(TestType, col) if isinstance(col, str) else col + for col in cls._summary_columns + ] + total_col = func.count().over().label("total_count") + query = ( + select(*select_columns, total_col) + .join(TestType, cls.test_type == TestType.test_type) + .where(*clauses) + .order_by(*(order_by or cls._default_order_by)) + .offset(page_index * page_size) + .limit(page_size) + ) + rows = get_current_session().execute(query).mappings().all() + items = [TestDefinitionSummary(**{k: v for k, v in row.items() if k != "total_count"}) for row in rows] + total = rows[0]["total_count"] if rows else 0 + return items, total + + _yn_columns: ClassVar = {"test_active", "lock_refresh"} diff --git a/testgen/ui/views/test_definitions.py b/testgen/ui/views/test_definitions.py index a73f30c9..e8d066d1 100644 --- a/testgen/ui/views/test_definitions.py +++ b/testgen/ui/views/test_definitions.py @@ -143,11 +143,9 @@ def render( with st.spinner("Loading data ..."): user_can_edit = session.auth.user_has_permission("edit") user_can_disposition = session.auth.user_has_permission("disposition") - df = get_test_definitions(test_suite, table_name, column_name, test_type, sorting_columns, - page=current_page, page_size=current_page_size, - flagged_filter=flagged) - total_count = get_test_definitions_count(test_suite, table_name, column_name, test_type, - flagged_filter=flagged) + df, total_count = get_test_definitions(test_suite, table_name, column_name, test_type, sorting_columns, + page_index=current_page, page_size=current_page_size, + flagged_filter=flagged) test_types = run_test_type_lookup_query().to_dict("records") table_columns = get_columns(str(table_group.id)) filter_columns_df = get_test_suite_columns(test_suite_id) @@ -532,7 +530,7 @@ def on_export_filtered(payload: dict) -> None: def on_export_selected(payload: dict) -> None: ids = payload.get("ids", []) if ids: - data = get_test_definitions(test_suite) + data, _ = get_test_definitions(test_suite) data = data[data["id"].isin(ids)] download_dialog( dialog_title="Download Excel Report", @@ -681,7 +679,7 @@ def get_excel_report_data( if data is not None: data = data.copy() else: - data = get_test_definitions(test_suite) + data, _ = get_test_definitions(test_suite) for key in ["test_active_display", "lock_refresh_display", "flagged_display"]: data[key] = data[key].apply(lambda val: val if val == "Yes" else None) @@ -774,10 +772,17 @@ def get_test_definitions( column_name: str | None = None, test_type: str | None = None, sorting_columns: list[tuple] | None = None, - page: int = 0, - page_size: int = 0, + page_index: int | None = None, + page_size: int = 500, flagged_filter: str | None = None, -) -> pd.DataFrame: +) -> tuple[pd.DataFrame, int]: + """Return ``(df, total_count)`` for one page of test definitions. + + When ``page_index`` is provided (0-based), fetches only that page from + the DB using ``TestDefinition._paginate()``; otherwise fetches all rows + via ``select_where()``. ``total_count`` is the full count of matching + rows regardless of which page is requested. + """ clauses = [TestDefinition.test_suite_id == test_suite.id] if table_name: clauses.append(TestDefinition.table_name == table_name) @@ -805,16 +810,18 @@ def get_test_definitions( else: order_by.append(sort_funcs[direction](func.lower(getattr(TestDefinition, attribute)))) - # For pagination, we need to bypass the base select_where which doesn't support offset/limit. - # We'll fetch all matching results and slice in Python. - test_definitions = TestDefinition.select_where( - *clauses, - order_by=tuple(order_by) if order_by else None, - ) + order_by_tuple = tuple(order_by) if order_by else None - if page_size > 0: - offset = page * page_size - test_definitions = list(test_definitions)[offset:offset + page_size] + if page_index is not None: + test_definitions, total_count = TestDefinition._paginate( + *clauses, + order_by=order_by_tuple, + page_index=page_index, + page_size=page_size, + ) + else: + test_definitions = TestDefinition.select_where(*clauses, order_by=order_by_tuple) + total_count = len(test_definitions) df = to_dataframe(test_definitions, TestDefinitionSummary.columns()) date_service.accommodate_dataframe_to_timezone(df, st.session_state) @@ -846,7 +853,7 @@ def get_export_to_observability_display(value: str) -> str: for col in df.select_dtypes(include=["datetime"]).columns: df[col] = df[col].astype(str).replace("NaT", "") - return df + return df, total_count def get_test_definitions_count( diff --git a/tests/unit/common/models/test_test_definition_pagination.py b/tests/unit/common/models/test_test_definition_pagination.py new file mode 100644 index 00000000..9f8edefd --- /dev/null +++ b/tests/unit/common/models/test_test_definition_pagination.py @@ -0,0 +1,145 @@ +from datetime import datetime +from unittest.mock import patch +from uuid import uuid4 + +import pytest + +from testgen.common.models.test_definition import TestDefinition, TestDefinitionSummary + +pytestmark = pytest.mark.unit + + +@pytest.fixture(autouse=True) +def clear_streamlit_cache(): + TestDefinition._paginate.clear() + yield + + +def _make_row(table_name: str = "my_table", total_count: int = 10) -> dict: + """Return a minimal row dict as returned by session.execute().mappings().all().""" + return { + # TestDefinitionSummary fields + "id": uuid4(), + "table_groups_id": uuid4(), + "profile_run_id": uuid4(), + "test_type": "CUSTOM", + "test_suite_id": uuid4(), + "test_description": None, + "schema_name": "public", + "table_name": table_name, + "column_name": "col1", + "skip_errors": 0, + "baseline_ct": None, + "baseline_unique_ct": None, + "baseline_value": None, + "baseline_value_ct": None, + "threshold_value": None, + "baseline_sum": None, + "baseline_avg": None, + "baseline_sd": None, + "lower_tolerance": None, + "upper_tolerance": None, + "subset_condition": None, + "groupby_names": None, + "having_condition": None, + "window_date_column": None, + "window_days": None, + "match_schema_name": None, + "match_table_name": None, + "match_column_names": None, + "match_subset_condition": None, + "match_groupby_names": None, + "match_having_condition": None, + "custom_query": None, + "history_calculation": None, + "history_calculation_upper": None, + "history_lookback": None, + "test_active": True, + "test_definition_status": None, + "severity": None, + "lock_refresh": False, + "last_auto_gen_date": None, + "profiling_as_of_date": None, + "last_manual_update": datetime.now(), + "export_to_observability": False, + "prediction": None, + "flagged": False, + # TestTypeSummary fields + "test_name_short": "Custom", + "default_test_description": "A test", + "measure_uom": "", + "measure_uom_description": "", + "default_parm_columns": "", + "default_parm_prompts": "", + "default_parm_help": "", + "default_parm_required": "", + "default_severity": "Warning", + "test_scope": "column", + "dq_dimension": "", + "usage_notes": "", + # Window function extra column + "total_count": total_count, + } + + +@patch("testgen.common.models.test_definition.get_current_session") +def test__paginate_returns_items_and_total(mock_get_session): + rows = [_make_row("table_a", total_count=3), _make_row("table_b", total_count=3), _make_row("table_c", total_count=3)] + mock_get_session.return_value.execute.return_value.mappings.return_value.all.return_value = rows + + items, total = TestDefinition._paginate() + + assert total == 3 + assert len(items) == 3 + assert all(isinstance(item, TestDefinitionSummary) for item in items) + assert items[0].table_name == "table_a" + assert items[2].table_name == "table_c" + + +@patch("testgen.common.models.test_definition.get_current_session") +def test__paginate_empty_result_returns_zero_total(mock_get_session): + mock_get_session.return_value.execute.return_value.mappings.return_value.all.return_value = [] + + items, total = TestDefinition._paginate() + + assert items == [] + assert total == 0 + + +@patch("testgen.common.models.test_definition.get_current_session") +def test__paginate_total_count_not_in_item_fields(mock_get_session): + mock_get_session.return_value.execute.return_value.mappings.return_value.all.return_value = [_make_row()] + + items, _ = TestDefinition._paginate() + + assert not hasattr(items[0], "total_count") + + +@patch("testgen.common.models.test_definition.get_current_session") +def test__paginate_uses_correct_offset_and_limit(mock_get_session): + mock_get_session.return_value.execute.return_value.mappings.return_value.all.return_value = [] + + TestDefinition._paginate(page_index=2, page_size=100) + + call_args = mock_get_session.return_value.execute.call_args + query = call_args[0][0] + compiled = query.compile(compile_kwargs={"literal_binds": True}) + sql = str(compiled) + + assert "LIMIT 100" in sql + assert "OFFSET 200" in sql + + +@patch("testgen.common.models.test_definition.get_current_session") +def test__paginate_page_zero_has_no_offset(mock_get_session): + mock_get_session.return_value.execute.return_value.mappings.return_value.all.return_value = [] + + TestDefinition._paginate(page_index=0, page_size=500) + + call_args = mock_get_session.return_value.execute.call_args + query = call_args[0][0] + compiled = query.compile(compile_kwargs={"literal_binds": True}) + sql = str(compiled) + + assert "LIMIT 500" in sql + assert "OFFSET 0" in sql From 75eafe551ea73010b441eccd3199fa32b31e6ecd Mon Sep 17 00:00:00 2001 From: Astor Date: Fri, 15 May 2026 12:15:55 -0300 Subject: [PATCH 21/58] refactor(TG-1041): address reviewer feedback on pagination implementation - Remove duplicate _paginate method from Entity that shadowed the existing one - Remove clear_cache method that had been re-added (was previously removed) - Revert list_for_suite to use cls._paginate() instead of inline query logic - Rename TestDefinition._paginate to select_page (public) and delegate to Entity._paginate - Fix re-render loop in on_filter_changed, on_page_changed, on_sort_changed handlers - Update unit tests to reflect renamed method and refactored Entity._paginate call path Co-Authored-By: Claude Sonnet 4.6 --- testgen/common/models/entity.py | 34 ---------- testgen/common/models/test_definition.py | 22 ++----- testgen/ui/views/test_definitions.py | 15 ++++- .../models/test_test_definition_pagination.py | 64 +++++++++++-------- 4 files changed, 57 insertions(+), 78 deletions(-) diff --git a/testgen/common/models/entity.py b/testgen/common/models/entity.py index 98671965..8f055bda 100644 --- a/testgen/common/models/entity.py +++ b/testgen/common/models/entity.py @@ -95,34 +95,6 @@ def select_where(cls, *clauses, order_by: tuple[str | InstrumentedAttribute] | N query = select(cls).where(*clauses).order_by(*order_by) return get_current_session().scalars(query).all() - @classmethod - @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) - def _paginate( - cls, - *clauses, - order_by: tuple[str | InstrumentedAttribute] | None = None, - page_index: int = 0, - page_size: int = 500, - ) -> tuple[list[Self], int]: - """Fetch one page of rows plus the total matching count via a window function. - - Uses ``COUNT(*) OVER()`` so only one round-trip to the database is needed. - Returns ``(items, total_count)``. ``page_index`` is 0-based. - """ - order_by = order_by or cls._default_order_by - total_col = func.count().over().label("total_count") - query = ( - select(cls, total_col) - .where(*clauses) - .order_by(*order_by) - .offset(page_index * page_size) - .limit(page_size) - ) - rows = get_current_session().execute(query).all() - items = [row[0] for row in rows] - total = rows[0][1] if rows else 0 - return items, total - @classmethod def select_minimal_where(cls, *clauses, order_by: tuple[str | InstrumentedAttribute]) -> Iterable[Any]: raise NotImplementedError @@ -192,12 +164,6 @@ def is_in_use(cls, ids: list[str]) -> bool: def cascade_delete(cls, ids: list[str]) -> None: raise NotImplementedError - @classmethod - def clear_cache(cls) -> None: - cls.get.clear() - cls.select_where.clear() - cls._paginate.clear() - @classmethod def columns(cls) -> list[str]: return list(cls.__annotations__.keys()) diff --git a/testgen/common/models/test_definition.py b/testgen/common/models/test_definition.py index 9a73717a..c7f9ca17 100644 --- a/testgen/common/models/test_definition.py +++ b/testgen/common/models/test_definition.py @@ -419,9 +419,8 @@ def list_for_suite( getattr(cls, col, None) or getattr(TestType, col) if isinstance(col, str) else col for col in cls._summary_columns ] - total_col = func.count().over().label("total_count") query = ( - select(*select_columns, total_col) + select(*select_columns) .join(TestType, cls.test_type == TestType.test_type) .join(TestSuite, cls.test_suite_id == TestSuite.id) .where(cls.test_suite_id == test_suite_id, TestSuite.is_monitor.isnot(True)) @@ -434,15 +433,12 @@ def list_for_suite( query = query.where(cls.test_type == test_type) if test_active is not None: query = query.where(cls.test_active == test_active) - query = query.order_by(*cls._default_order_by).offset((page - 1) * limit).limit(limit) - rows = get_current_session().execute(query).mappings().all() - items = [TestDefinitionSummary(**{k: v for k, v in row.items() if k != "total_count"}) for row in rows] - total = rows[0]["total_count"] if rows else 0 - return items, total + query = query.order_by(*cls._default_order_by) + return cls._paginate(query, page=page, limit=limit, data_class=TestDefinitionSummary) @classmethod @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) - def _paginate( + def select_page( cls, *clauses, order_by: tuple[str | InstrumentedAttribute] | None = None, @@ -453,19 +449,13 @@ def _paginate( getattr(cls, col, None) or getattr(TestType, col) if isinstance(col, str) else col for col in cls._summary_columns ] - total_col = func.count().over().label("total_count") query = ( - select(*select_columns, total_col) + select(*select_columns) .join(TestType, cls.test_type == TestType.test_type) .where(*clauses) .order_by(*(order_by or cls._default_order_by)) - .offset(page_index * page_size) - .limit(page_size) ) - rows = get_current_session().execute(query).mappings().all() - items = [TestDefinitionSummary(**{k: v for k, v in row.items() if k != "total_count"}) for row in rows] - total = rows[0]["total_count"] if rows else 0 - return items, total + return cls._paginate(query, page=page_index + 1, limit=page_size, data_class=TestDefinitionSummary) diff --git a/testgen/ui/views/test_definitions.py b/testgen/ui/views/test_definitions.py index e8d066d1..707bcc58 100644 --- a/testgen/ui/views/test_definitions.py +++ b/testgen/ui/views/test_definitions.py @@ -539,11 +539,22 @@ def on_export_selected(payload: dict) -> None: ) def on_filter_changed(filters: dict) -> None: + norm = lambda v: None if v in (None, "None", "") else str(v) + if ( + norm(filters.get("table_name")) == norm(table_name) + and norm(filters.get("column_name")) == norm(column_name) + and norm(filters.get("test_type")) == norm(test_type) + and norm(filters.get("flagged")) == norm(flagged) + and current_page == 0 + ): + return Router().set_query_params({**filters, "page": "0"}) def on_page_changed(payload: dict) -> None: new_page = payload.get("page", 0) new_page_size = payload.get("page_size") + if new_page == current_page and (new_page_size is None or int(new_page_size) == current_page_size): + return params: dict = {"page": str(new_page)} if new_page_size is not None: params["page_size"] = str(int(new_page_size)) @@ -557,6 +568,8 @@ def on_sort_changed(payload: dict) -> None: order = col.get("order", "asc") sort_parts.append(f"{field}:{order}") sort_value = ",".join(sort_parts) if sort_parts else None + if sort_value == sort and current_page == 0: + return Router().set_query_params({"sort": sort_value, "page": "0"}) testgen.test_definitions_widget( @@ -813,7 +826,7 @@ def get_test_definitions( order_by_tuple = tuple(order_by) if order_by else None if page_index is not None: - test_definitions, total_count = TestDefinition._paginate( + test_definitions, total_count = TestDefinition.select_page( *clauses, order_by=order_by_tuple, page_index=page_index, diff --git a/tests/unit/common/models/test_test_definition_pagination.py b/tests/unit/common/models/test_test_definition_pagination.py index 9f8edefd..fa844f90 100644 --- a/tests/unit/common/models/test_test_definition_pagination.py +++ b/tests/unit/common/models/test_test_definition_pagination.py @@ -11,11 +11,11 @@ @pytest.fixture(autouse=True) def clear_streamlit_cache(): - TestDefinition._paginate.clear() + TestDefinition.select_page.clear() yield -def _make_row(table_name: str = "my_table", total_count: int = 10) -> dict: +def _make_row(table_name: str = "my_table") -> dict: """Return a minimal row dict as returned by session.execute().mappings().all().""" return { # TestDefinitionSummary fields @@ -64,6 +64,7 @@ def _make_row(table_name: str = "my_table", total_count: int = 10) -> dict: "export_to_observability": False, "prediction": None, "flagged": False, + "impact_dimension": None, # TestTypeSummary fields "test_name_short": "Custom", "default_test_description": "A test", @@ -76,18 +77,19 @@ def _make_row(table_name: str = "my_table", total_count: int = 10) -> dict: "default_severity": "Warning", "test_scope": "column", "dq_dimension": "", + "default_impact_dimension": "", "usage_notes": "", - # Window function extra column - "total_count": total_count, } -@patch("testgen.common.models.test_definition.get_current_session") -def test__paginate_returns_items_and_total(mock_get_session): - rows = [_make_row("table_a", total_count=3), _make_row("table_b", total_count=3), _make_row("table_c", total_count=3)] - mock_get_session.return_value.execute.return_value.mappings.return_value.all.return_value = rows +@patch("testgen.common.models.entity.get_current_session") +def test_select_page_returns_items_and_total(mock_get_session): + rows = [_make_row("table_a"), _make_row("table_b"), _make_row("table_c")] + mock_session = mock_get_session.return_value + mock_session.scalar.return_value = 3 + mock_session.execute.return_value.mappings.return_value.all.return_value = rows - items, total = TestDefinition._paginate() + items, total = TestDefinition.select_page() assert total == 3 assert len(items) == 3 @@ -96,32 +98,38 @@ def test__paginate_returns_items_and_total(mock_get_session): assert items[2].table_name == "table_c" -@patch("testgen.common.models.test_definition.get_current_session") -def test__paginate_empty_result_returns_zero_total(mock_get_session): - mock_get_session.return_value.execute.return_value.mappings.return_value.all.return_value = [] +@patch("testgen.common.models.entity.get_current_session") +def test_select_page_empty_result_returns_zero_total(mock_get_session): + mock_session = mock_get_session.return_value + mock_session.scalar.return_value = 0 + mock_session.execute.return_value.mappings.return_value.all.return_value = [] - items, total = TestDefinition._paginate() + items, total = TestDefinition.select_page() assert items == [] assert total == 0 -@patch("testgen.common.models.test_definition.get_current_session") -def test__paginate_total_count_not_in_item_fields(mock_get_session): - mock_get_session.return_value.execute.return_value.mappings.return_value.all.return_value = [_make_row()] +@patch("testgen.common.models.entity.get_current_session") +def test_select_page_item_has_no_total_count_field(mock_get_session): + mock_session = mock_get_session.return_value + mock_session.scalar.return_value = 1 + mock_session.execute.return_value.mappings.return_value.all.return_value = [_make_row()] - items, _ = TestDefinition._paginate() + items, _ = TestDefinition.select_page() assert not hasattr(items[0], "total_count") -@patch("testgen.common.models.test_definition.get_current_session") -def test__paginate_uses_correct_offset_and_limit(mock_get_session): - mock_get_session.return_value.execute.return_value.mappings.return_value.all.return_value = [] +@patch("testgen.common.models.entity.get_current_session") +def test_select_page_uses_correct_offset_and_limit(mock_get_session): + mock_session = mock_get_session.return_value + mock_session.scalar.return_value = 0 + mock_session.execute.return_value.mappings.return_value.all.return_value = [] - TestDefinition._paginate(page_index=2, page_size=100) + TestDefinition.select_page(page_index=2, page_size=100) - call_args = mock_get_session.return_value.execute.call_args + call_args = mock_session.execute.call_args query = call_args[0][0] compiled = query.compile(compile_kwargs={"literal_binds": True}) sql = str(compiled) @@ -130,13 +138,15 @@ def test__paginate_uses_correct_offset_and_limit(mock_get_session): assert "OFFSET 200" in sql -@patch("testgen.common.models.test_definition.get_current_session") -def test__paginate_page_zero_has_no_offset(mock_get_session): - mock_get_session.return_value.execute.return_value.mappings.return_value.all.return_value = [] +@patch("testgen.common.models.entity.get_current_session") +def test_select_page_page_zero_has_no_offset(mock_get_session): + mock_session = mock_get_session.return_value + mock_session.scalar.return_value = 0 + mock_session.execute.return_value.mappings.return_value.all.return_value = [] - TestDefinition._paginate(page_index=0, page_size=500) + TestDefinition.select_page(page_index=0, page_size=500) - call_args = mock_get_session.return_value.execute.call_args + call_args = mock_session.execute.call_args query = call_args[0][0] compiled = query.compile(compile_kwargs={"literal_binds": True}) sql = str(compiled) From e33ef2f70a656ca7fb862a02e795c221b194f26b Mon Sep 17 00:00:00 2001 From: Ricardo Boni Date: Fri, 15 May 2026 16:36:20 -0400 Subject: [PATCH 22/58] fix(scoring): accept leading-dot decimals in fn_eval MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Oracle's NUMBER → VARCHAR2 conversion drops the leading zero for values where |x| < 1 (e.g. ".733..." instead of "0.733..."). The value flows verbatim into test_results.result_measure (VARCHAR) and from there into the DQ scoring prevalence formula, which feeds fn_eval. The previous numeric token pattern required [0-9]+ at the start and rejected the leading-dot form as "invalid token \".\"". Loosen the pattern to `[0-9]+\.?[0-9]*|\.[0-9]+` so both `.733...` and `0.733...` parse. Includes the migration script (0189) that re-issues CREATE OR REPLACE FUNCTION for existing installs. --- .../020_create_standard_functions_sprocs.sql | 6 ++- .../dbupgrade/0189_incremental_upgrade.sql | 48 +++++++++++++++++++ 2 files changed, 52 insertions(+), 2 deletions(-) create mode 100644 testgen/template/dbupgrade/0189_incremental_upgrade.sql diff --git a/testgen/template/dbsetup/020_create_standard_functions_sprocs.sql b/testgen/template/dbsetup/020_create_standard_functions_sprocs.sql index 013343f0..57b83256 100644 --- a/testgen/template/dbsetup/020_create_standard_functions_sprocs.sql +++ b/testgen/template/dbsetup/020_create_standard_functions_sprocs.sql @@ -226,10 +226,12 @@ BEGIN RAISE EXCEPTION 'Invalid expression: dangerous statement detected'; END IF; - -- Remove all allowed tokens from the validation expression, treating 'FLOAT' as a keyword + -- Remove all allowed tokens from the validation expression, treating 'FLOAT' as a keyword. + -- Numeric pattern accepts leading-dot decimals (e.g. ".733") that Oracle emits + -- when converting NUMBER values with |x| < 1 to VARCHAR2. invalid_parts := regexp_replace( expression, - E'(\\mGREATEST|LEAST|ABS|FN_NORMAL_CDF|DATEDIFF|DAY|FLOAT|NULLIF)\\M|[0-9]+(\\.[0-9]+)?([eE][+-]?[0-9]+)?|[+\\-*/(),\\\'":]+|\\s+', + E'(\\mGREATEST|LEAST|ABS|FN_NORMAL_CDF|DATEDIFF|DAY|FLOAT|NULLIF)\\M|([0-9]+\\.?[0-9]*|\\.[0-9]+)([eE][+-]?[0-9]+)?|[+\\-*/(),\\\'":]+|\\s+', '', 'gi' ); diff --git a/testgen/template/dbupgrade/0189_incremental_upgrade.sql b/testgen/template/dbupgrade/0189_incremental_upgrade.sql new file mode 100644 index 00000000..96227490 --- /dev/null +++ b/testgen/template/dbupgrade/0189_incremental_upgrade.sql @@ -0,0 +1,48 @@ +SET SEARCH_PATH TO {SCHEMA_NAME}; + +-- Loosen fn_eval's numeric token pattern to accept leading-dot decimals +-- (e.g. ".733"). Oracle's NUMBER -> VARCHAR2 conversion drops the leading +-- zero for |x| < 1, and the value flows verbatim into test_results.result_measure +-- (VARCHAR), so the DQ scoring prevalence formula like +-- 2.0 * (1.0 - fn_normal_cdf(ABS({RESULT_MEASURE}::FLOAT) / 2.0)) +-- fed ".733..." to fn_eval, which rejected it as "invalid token \".\"". + +CREATE OR REPLACE FUNCTION {SCHEMA_NAME}.fn_eval(expression TEXT) RETURNS FLOAT +AS +$$ +DECLARE + result FLOAT; + invalid_parts TEXT; +BEGIN + -- Check the modified expression for invalid characters, allowing colons + IF expression ~* E'[^0-9+\\-*/(),.\\sA-Z_:e\\\'"]' THEN + RAISE EXCEPTION 'Invalid characters detected in expression: %', expression; + END IF; + + -- Check for dangerous PostgreSQL-specific keywords + IF expression ~* E'\b(DROP|ALTER|INSERT|UPDATE|DELETE|TRUNCATE|GRANT|REVOKE|COPY|EXECUTE|CREATE|COMMENT|SECURITY|WITH|SET ROLE|SET SESSION|DO|CALL|--|/\\*|;|pg_read_file|pg_write_file|pg_terminate_backend)\b' THEN + RAISE EXCEPTION 'Invalid expression: dangerous statement detected'; + END IF; + + -- Remove all allowed tokens from the validation expression, treating 'FLOAT' as a keyword. + -- Numeric pattern accepts leading-dot decimals (e.g. ".733") that Oracle emits + -- when converting NUMBER values with |x| < 1 to VARCHAR2. + invalid_parts := regexp_replace( + expression, + E'(\\mGREATEST|LEAST|ABS|FN_NORMAL_CDF|DATEDIFF|DAY|FLOAT|NULLIF)\\M|([0-9]+\\.?[0-9]*|\\.[0-9]+)([eE][+-]?[0-9]+)?|[+\\-*/(),\\\'":]+|\\s+', + '', + 'gi' + ); + + -- If anything is left in the validation expression, it's invalid + IF invalid_parts <> '' THEN + RAISE EXCEPTION 'Invalid expression contains invalid tokens "%" in expression: %', invalid_parts, expression; + END IF; + + -- Use the original expression (with ::FLOAT) for execution + EXECUTE format('SELECT (%s)::FLOAT', expression) INTO result; + + RETURN result; +END; +$$ +LANGUAGE plpgsql; From ce8cca3fd623e2362b28bc011d3ce3300d28a71b Mon Sep 17 00:00:00 2001 From: Ricardo Boni Date: Fri, 15 May 2026 15:49:25 -0400 Subject: [PATCH 23/58] refactor(mcp): apply TG-1067 review feedback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Argument rename: list_column_profiles arg `functional_data_type` → `semantic_data_type`, matching the user-facing display label. - Score scale: `score_profiling_*` / `score_testing_*` filter args now accept the 0-100 scale shown in the output; division to the 0-1 DB scale happens at the WHERE boundary. - SuggestedDataType enum gains `Smallint`, `Bigint`, `Decimal` (per the vocabulary the engine emits via `datatype_suggestions.sql`). - `get_column_frequent_values` / `get_column_patterns` source `pii_flag` from `DataColumnChars` (canonical) rather than `ProfileResult`. Header label `Records` → `Row Count` to match `get_column_profile_detail`. - `DataColumnChars.list_for_table_group` appends a deterministic tiebreaker `(table_name, ordinal_position, column_name)` to every non-default ORDER BY so paginated callers see no row skipping or duplication across pages. - Tighter `list_column_profiles` arg docstring (drops the misleading "open-ended" framing and the multi-call strategy hint). Co-Authored-By: Claude Opus 4.7 (1M context) --- testgen/common/models/data_column.py | 23 +++-- testgen/mcp/tools/profiling.py | 78 ++++++++-------- testgen/mcp/tools/reference.py | 29 ++++++ tests/unit/mcp/test_tools_profiling.py | 118 +++++++++++++++++++++---- 4 files changed, 188 insertions(+), 60 deletions(-) diff --git a/testgen/common/models/data_column.py b/testgen/common/models/data_column.py index cf350519..4027c2ca 100644 --- a/testgen/common/models/data_column.py +++ b/testgen/common/models/data_column.py @@ -53,7 +53,10 @@ class SuggestedDataType(StrEnum): """Values accepted for the ``suggested_data_type`` argument.""" ANY = "Any" + SMALLINT = "Smallint" INTEGER = "Integer" + BIGINT = "Bigint" + DECIMAL = "Decimal" NUMERIC = "Numeric" VARCHAR = "Varchar" DATE = "Date" @@ -65,7 +68,10 @@ class SuggestedDataType(StrEnum): # ``datatype_suggestion`` (``Any`` is a sentinel — no prefix, just non-null check). SUGGESTED_DATA_TYPE_TO_PREFIX: dict[SuggestedDataType, str | None] = { SuggestedDataType.ANY: None, + SuggestedDataType.SMALLINT: "SMALLINT", SuggestedDataType.INTEGER: "INTEGER", + SuggestedDataType.BIGINT: "BIGINT", + SuggestedDataType.DECIMAL: "DECIMAL", SuggestedDataType.NUMERIC: "NUMERIC", SuggestedDataType.VARCHAR: "VARCHAR", SuggestedDataType.DATE: "DATE", @@ -323,21 +329,24 @@ def list_for_table_group( null_ratio_expr = ProfileResult.null_value_ct * 1.0 / func.nullif(ProfileResult.record_ct, 0) distinct_ratio_expr = ProfileResult.distinct_value_ct * 1.0 / func.nullif(ProfileResult.record_ct, 0) filled_ratio_expr = ProfileResult.filled_value_ct * 1.0 / func.nullif(ProfileResult.record_ct, 0) + # Deterministic tiebreaker so paginated callers don't see rows skip or duplicate + # across pages when the primary sort has ties. + tiebreaker = (asc(cls.table_name), asc(cls.ordinal_position), asc(cls.column_name)) order_exprs: tuple if order_by is ColumnOrderBy.NULL_RATIO: - order_exprs = (desc(null_ratio_expr).nulls_last(),) + order_exprs = (desc(null_ratio_expr).nulls_last(), *tiebreaker) elif order_by is ColumnOrderBy.DISTINCT_RATIO: - order_exprs = (asc(distinct_ratio_expr).nulls_last(),) + order_exprs = (asc(distinct_ratio_expr).nulls_last(), *tiebreaker) elif order_by is ColumnOrderBy.FILLED_RATIO: - order_exprs = (desc(filled_ratio_expr).nulls_last(),) + order_exprs = (desc(filled_ratio_expr).nulls_last(), *tiebreaker) elif order_by is ColumnOrderBy.SCORE_PROFILING: - order_exprs = (asc(cls.dq_score_profiling).nulls_last(),) + order_exprs = (asc(cls.dq_score_profiling).nulls_last(), *tiebreaker) elif order_by is ColumnOrderBy.SCORE_TESTING: - order_exprs = (asc(cls.dq_score_testing).nulls_last(),) + order_exprs = (asc(cls.dq_score_testing).nulls_last(), *tiebreaker) elif order_by is ColumnOrderBy.HYGIENE_COUNT: - order_exprs = (desc(func.coalesce(hygiene_subq.c.hygiene_issue_count, 0)),) + order_exprs = (desc(func.coalesce(hygiene_subq.c.hygiene_issue_count, 0)), *tiebreaker) else: - order_exprs = (asc(cls.table_name), asc(cls.ordinal_position), asc(cls.column_name)) + order_exprs = tiebreaker query = query.order_by(*order_exprs) diff --git a/testgen/mcp/tools/profiling.py b/testgen/mcp/tools/profiling.py index 4e3034f7..2f01de89 100644 --- a/testgen/mcp/tools/profiling.py +++ b/testgen/mcp/tools/profiling.py @@ -111,7 +111,7 @@ def list_column_profiles( cde: bool | None = None, suggested_data_type: str | None = None, general_type: str | None = None, - functional_data_type: str | None = None, + semantic_data_type: str | None = None, pii_category: str | None = None, pii_risk_level: str | None = None, order_by: str | None = None, @@ -137,27 +137,23 @@ def list_column_profiles( this value. filled_ratio_below: Match columns whose dummy/placeholder-value fraction is below this value. - score_profiling_above: Match columns whose Profiling Score is above this value. - score_profiling_below: Match columns whose Profiling Score is below this value. - score_testing_above: Match columns whose Testing Score is above this value. - score_testing_below: Match columns whose Testing Score is below this value. + score_profiling_above: Match columns whose Profiling Score is above this value (0-100 scale). + score_profiling_below: Match columns whose Profiling Score is below this value (0-100 scale). + score_testing_above: Match columns whose Testing Score is above this value (0-100 scale). + score_testing_below: Match columns whose Testing Score is below this value (0-100 scale). pii: When `true`, match columns flagged as PII; when `false`, exclude PII columns. cde: When `true`, match columns flagged as a Critical Data Element (directly or inherited from the table); when `false`, exclude CDE columns. suggested_data_type: Match columns where profiling suggests a more suitable data - type. Pass `Any` for any mismatch, or a concrete type (`Integer`, `Numeric`, - `Varchar`, `Date`, `Timestamp`, `Boolean`) to filter mismatches whose - suggestion starts with that type. Columns where the suggestion matches the - column's stored type are always excluded. + type. Pass `Any` for any mismatch, or a concrete type (`Smallint`, `Integer`, + `Bigint`, `Decimal`, `Numeric`, `Varchar`, `Date`, `Timestamp`, `Boolean`) to + filter mismatches whose suggestion starts with that type. Columns where the + suggestion matches the column's stored type are always excluded. general_type: Broad type classification — `Alpha`, `Numeric`, `Datetime`, `Boolean`, `Time`, or `Other`. - functional_data_type: Substring match (case-insensitive) on Semantic Data Type. - Use a cluster prefix to catch related variants — `Period` matches - `Period`, `Period Month`, `Period Year`, etc.; `ID` matches `ID`, `ID-FK`, - `ID-Unique`, etc.; `Transactional Date` matches all of its variants. Bare - tokens auto-wrap with `%`; an explicit `%` in the input is honored as a - wildcard. The set of values is open-ended — discover available values by - listing columns without this filter, then narrow. + semantic_data_type: Substring match (case-insensitive) on Semantic Data Type. + Bare tokens auto-wrap with `%`; an explicit `%` is honored as a wildcard. + See `testgen://column-profile-fields` for the canonical value list. pii_category: PII category — `ID`, `Name`, `Demographic`, or `Contact`. pii_risk_level: PII risk level — `High`, `Moderate`, or `Low`. order_by: Sort key — `Null Ratio`, `Distinct Ratio`, `Filled Ratio`, @@ -206,13 +202,13 @@ def list_column_profiles( ) if score_profiling_above is not None: - clauses.append(DataColumnChars.dq_score_profiling > score_profiling_above) + clauses.append(DataColumnChars.dq_score_profiling > score_profiling_above / 100) if score_profiling_below is not None: - clauses.append(DataColumnChars.dq_score_profiling < score_profiling_below) + clauses.append(DataColumnChars.dq_score_profiling < score_profiling_below / 100) if score_testing_above is not None: - clauses.append(DataColumnChars.dq_score_testing > score_testing_above) + clauses.append(DataColumnChars.dq_score_testing > score_testing_above / 100) if score_testing_below is not None: - clauses.append(DataColumnChars.dq_score_testing < score_testing_below) + clauses.append(DataColumnChars.dq_score_testing < score_testing_below / 100) if pii is True: clauses.append(DataColumnChars.pii_flag.isnot(None)) @@ -244,12 +240,12 @@ def list_column_profiles( if general_type is not None: clauses.append(DataColumnChars.general_type == parse_general_type(general_type)) - if functional_data_type is not None: - if not functional_data_type.strip(): - raise MCPUserError("`functional_data_type` cannot be empty.") + if semantic_data_type is not None: + if not semantic_data_type.strip(): + raise MCPUserError("`semantic_data_type` cannot be empty.") clauses.append( DataColumnChars.functional_data_type.ilike( - build_ilike_pattern(functional_data_type), escape="\\" + build_ilike_pattern(semantic_data_type), escape="\\" ) ) if pii_category is not None: @@ -659,8 +655,12 @@ def _load_profile_for_column( table_name: str, column_name: str, job_execution_id: str | None, -) -> tuple[ProfileResult, ProfilingRun]: - """Resolve and load the profile-results row for one column, paired with its ``ProfilingRun``.""" +) -> tuple[ProfileResult, ProfilingRun, str | None]: + """Resolve and load the profile-results row for one column. + + Returns a triple of ``(profile, profiling_run, pii_flag)`` where ``pii_flag`` is + pulled from ``data_column_chars`` (the source of truth for column-level PII state). + """ profiling_run: ProfilingRun | None = None if job_execution_id: profiling_run = resolve_profiling_run(job_execution_id) @@ -678,12 +678,18 @@ def _load_profile_for_column( profiling_run = ProfilingRun.get(profile.profile_run_id) if profiling_run is None: raise MCPResourceNotAccessible("Profiling run", str(profile.profile_run_id)) - return profile, profiling_run + column_rows = list(DataColumnChars.select_where( + DataColumnChars.table_groups_id == tg.id, + DataColumnChars.table_name == table_name, + DataColumnChars.column_name == column_name, + )) + pii_flag = column_rows[0].pii_flag if column_rows else None + return profile, profiling_run, pii_flag -def _is_pii_redacted_for_caller(tg: TableGroup, profile: ProfileResult) -> bool: +def _is_pii_redacted_for_caller(tg: TableGroup, pii_flag: str | None) -> bool: """Decide whether to redact PII values for this caller + column.""" - if not profile.pii_flag: + if not pii_flag: return False return not get_project_permissions().has_permission("view_pii", tg.project_code) @@ -937,16 +943,16 @@ def get_column_frequent_values( latest profile run. """ tg = resolve_table_group(table_group_id) - profile, profiling_run = _load_profile_for_column(tg, table_name, column_name, job_execution_id) + profile, profiling_run, pii_flag = _load_profile_for_column(tg, table_name, column_name, job_execution_id) doc = MdDoc() doc.heading(1, f"Frequent values: {table_name}.{column_name}") doc.field("Table group", tg.id, code=True) doc.field("Profiling Run", profiling_run.job_execution_id, code=True) - doc.field("Records", profile.record_ct) + doc.field("Row Count", profile.record_ct) doc.field("Distinct values", profile.distinct_value_ct) - if profile.pii_flag: - doc.field("PII", _format_pii(profile.pii_flag)) + if pii_flag: + doc.field("PII", _format_pii(pii_flag)) rows = parse_top_freq_values(profile.top_freq_values) if not rows: @@ -956,7 +962,7 @@ def get_column_frequent_values( ) return doc.render() - redact = _is_pii_redacted_for_caller(tg, profile) + redact = _is_pii_redacted_for_caller(tg, pii_flag) record_ct = profile.record_ct or 0 display_rows: list[list[object]] = [] for value, count in rows: @@ -992,13 +998,13 @@ def get_column_patterns( latest profile run. """ tg = resolve_table_group(table_group_id) - profile, profiling_run = _load_profile_for_column(tg, table_name, column_name, job_execution_id) + profile, profiling_run, _ = _load_profile_for_column(tg, table_name, column_name, job_execution_id) doc = MdDoc() doc.heading(1, f"Character patterns: {table_name}.{column_name}") doc.field("Table group", tg.id, code=True) doc.field("Profiling Run", profiling_run.job_execution_id, code=True) - doc.field("Records", profile.record_ct) + doc.field("Row Count", profile.record_ct) doc.field("Distinct values", profile.distinct_value_ct) if profile.general_type and profile.general_type != "A": diff --git a/testgen/mcp/tools/reference.py b/testgen/mcp/tools/reference.py index d9aa75bc..655c257e 100644 --- a/testgen/mcp/tools/reference.py +++ b/testgen/mcp/tools/reference.py @@ -227,6 +227,35 @@ def column_profile_fields_resource() -> str: Aggregates, counts, `Frequent Patterns`, and `Standard Pattern Match` are never redacted — they're distribution-level signals that don't expose individual rows. + +## Semantic Data Type — values emitted by profiling, grouped by family. + +**Identifiers**: `ID`, `ID-FK`, `ID-Group`, `ID-Secondary`, `ID-SK`, +`ID-Unique`, `ID-Unique-SK` + +**Dates & schedules**: `Date Stamp`, `DateTime Stamp`, `Schedule Date`, +`Future Date`, `Historical Date`, `Transactional Date`, +`Transactional Date (Mo)`, `Transactional Date (Qtr)`, +`Transactional Date (Wk)` + +**Periods**: `Period`, `Period DOW`, `Period Mon-NN`, `Period Month`, +`Period Quarter`, `Period Week`, `Period Year`, `Period Year-Mon` + +**People**: `Person Full Name`, `Person Given Name`, `Person Last Name` + +**Location & contact**: `Address`, `City`, `State`, `Zip`, `Email`, `Phone` + +**Measurements**: `Measurement`, `Measurement Discrete`, `Measurement Pct`, +`Measurement Spike`, `Measurement Text` + +**Codes, flags, attributes**: `Attribute`, `Boolean`, `Code`, `Constant`, +`Flag`, `Sequence` + +**Entity & system**: `Entity Name`, `Process`, `Process User`, `System User` + +The `semantic_data_type` filter on `list_column_profiles` matches via `ILIKE`, +so partial inputs catch related variants (e.g. `ID` matches `ID`, `ID-FK`, +`ID-Group`, …). """ diff --git a/tests/unit/mcp/test_tools_profiling.py b/tests/unit/mcp/test_tools_profiling.py index 87c894f8..588bd42d 100644 --- a/tests/unit/mcp/test_tools_profiling.py +++ b/tests/unit/mcp/test_tools_profiling.py @@ -1331,6 +1331,37 @@ def test_list_column_profiles_null_ratio_above_adds_clause(mock_tg_cls, mock_dcc assert any("null_value_ct" in str(c) for c in clauses) +@patch.object(DataColumnChars, "list_for_table_group") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_score_profiling_above_converts_to_0_to_1_scale( + mock_tg_cls, mock_method, db_session_mock, +): + """The user-facing 0-100 score range maps to the 0-1 fraction the DB stores.""" + mock_tg_cls.get.return_value = _mock_table_group() + mock_method.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_column_profiles + list_column_profiles(str(uuid4()), score_profiling_above=70) + + sql = _compile_clauses(mock_method) + assert "dq_score_profiling > 0.7" in sql + + +@patch.object(DataColumnChars, "list_for_table_group") +@patch("testgen.mcp.tools.common.TableGroup") +def test_list_column_profiles_score_testing_below_converts_to_0_to_1_scale( + mock_tg_cls, mock_method, db_session_mock, +): + mock_tg_cls.get.return_value = _mock_table_group() + mock_method.return_value = ([], 0) + + from testgen.mcp.tools.profiling import list_column_profiles + list_column_profiles(str(uuid4()), score_testing_below=50) + + sql = _compile_clauses(mock_method) + assert "dq_score_testing < 0.5" in sql + + @patch.object(DataColumnChars, "list_for_table_group") @patch("testgen.mcp.tools.common.TableGroup") def test_list_column_profiles_pii_true_adds_is_not_null_clause(mock_tg_cls, mock_method, db_session_mock): @@ -1454,14 +1485,14 @@ def test_list_column_profiles_pii_risk_level_moderate_does_not_include_manual( @patch.object(DataColumnChars, "list_for_table_group") @patch("testgen.mcp.tools.common.TableGroup") -def test_list_column_profiles_functional_data_type_uses_ilike( +def test_list_column_profiles_semantic_data_type_uses_ilike( mock_tg_cls, mock_method, db_session_mock, ): mock_tg_cls.get.return_value = _mock_table_group() mock_method.return_value = ([], 0) from testgen.mcp.tools.profiling import list_column_profiles - list_column_profiles(str(uuid4()), functional_data_type="Person Given") + list_column_profiles(str(uuid4()), semantic_data_type="Person Given") sql = _compile_clauses(mock_method) # Default dialect renders ILIKE as ``LOWER(col) LIKE LOWER(pat) ESCAPE`` — same semantic. @@ -1471,7 +1502,7 @@ def test_list_column_profiles_functional_data_type_uses_ilike( @patch.object(DataColumnChars, "list_for_table_group") @patch("testgen.mcp.tools.common.TableGroup") -def test_list_column_profiles_functional_data_type_underscore_escaped( +def test_list_column_profiles_semantic_data_type_underscore_escaped( mock_tg_cls, mock_method, db_session_mock, ): """Underscores in the input must be escaped (column names commonly contain them).""" @@ -1479,7 +1510,7 @@ def test_list_column_profiles_functional_data_type_underscore_escaped( mock_method.return_value = ([], 0) from testgen.mcp.tools.profiling import list_column_profiles - list_column_profiles(str(uuid4()), functional_data_type="ID_FK") + list_column_profiles(str(uuid4()), semantic_data_type="ID_FK") sql = _compile_clauses(mock_method) # The escape clause appears, and the underscore is escaped in the pattern. @@ -1488,12 +1519,12 @@ def test_list_column_profiles_functional_data_type_underscore_escaped( @patch.object(DataColumnChars, "list_for_table_group") @patch("testgen.mcp.tools.common.TableGroup") -def test_list_column_profiles_functional_data_type_empty_rejected(mock_tg_cls, mock_method, db_session_mock): +def test_list_column_profiles_semantic_data_type_empty_rejected(mock_tg_cls, mock_method, db_session_mock): mock_tg_cls.get.return_value = _mock_table_group() from testgen.mcp.tools.profiling import list_column_profiles - with pytest.raises(MCPUserError, match="`functional_data_type` cannot be empty"): - list_column_profiles(str(uuid4()), functional_data_type=" ") + with pytest.raises(MCPUserError, match="`semantic_data_type` cannot be empty"): + list_column_profiles(str(uuid4()), semantic_data_type=" ") @patch.object(DataColumnChars, "list_for_table_group") @@ -1543,14 +1574,25 @@ def _mock_profiling_run_for_tg(tg_id): return pr +def _mock_data_column(pii_flag=None): + """Build a mock `DataColumnChars` row carrying just the fields the helper reads.""" + col = MagicMock() + col.pii_flag = pii_flag + return col + + +@patch.object(DataColumnChars, "select_where") @patch("testgen.mcp.tools.profiling.ProfilingRun") @patch("testgen.mcp.tools.profiling.ProfileResult") @patch("testgen.mcp.tools.common.TableGroup") -def test_get_column_frequent_values_happy_path(mock_tg_cls, mock_pr_cls, mock_run_cls, db_session_mock): +def test_get_column_frequent_values_happy_path( + mock_tg_cls, mock_pr_cls, mock_run_cls, mock_dcc_select, db_session_mock, +): tg = _mock_table_group() mock_tg_cls.get.return_value = tg mock_pr_cls.get_for_column.return_value = _mock_profile_result() mock_run_cls.get.return_value = _mock_profiling_run_for_tg(tg.id) + mock_dcc_select.return_value = [_mock_data_column()] from testgen.mcp.tools.profiling import get_column_frequent_values result = get_column_frequent_values(str(uuid4()), "customers", "country") @@ -1561,11 +1603,12 @@ def test_get_column_frequent_values_happy_path(mock_tg_cls, mock_pr_cls, mock_ru assert "Top values" in result +@patch.object(DataColumnChars, "select_where") @patch("testgen.mcp.tools.profiling.ProfilingRun") @patch("testgen.mcp.tools.profiling.ProfileResult") @patch("testgen.mcp.tools.common.TableGroup") def test_get_column_frequent_values_surfaces_job_execution_id_not_profile_run_id( - mock_tg_cls, mock_pr_cls, mock_run_cls, db_session_mock, + mock_tg_cls, mock_pr_cls, mock_run_cls, mock_dcc_select, db_session_mock, ): tg = _mock_table_group() mock_tg_cls.get.return_value = tg @@ -1573,6 +1616,7 @@ def test_get_column_frequent_values_surfaces_job_execution_id_not_profile_run_id mock_pr_cls.get_for_column.return_value = profile run = _mock_profiling_run_for_tg(tg.id) mock_run_cls.get.return_value = run + mock_dcc_select.return_value = [_mock_data_column()] from testgen.mcp.tools.profiling import get_column_frequent_values result = get_column_frequent_values(str(uuid4()), "customers", "country") @@ -1582,19 +1626,21 @@ def test_get_column_frequent_values_surfaces_job_execution_id_not_profile_run_id assert str(profile.profile_run_id) not in result +@patch.object(DataColumnChars, "select_where") @patch("testgen.mcp.tools.profiling.ProfilingRun") @patch("testgen.mcp.tools.profiling.ProfileResult") @patch("testgen.mcp.tools.common.TableGroup") def test_get_column_frequent_values_pii_value_redacted_when_caller_lacks_view_pii( - mock_tg_cls, mock_pr_cls, mock_run_cls, db_session_mock, + mock_tg_cls, mock_pr_cls, mock_run_cls, mock_dcc_select, db_session_mock, ): tg = _mock_table_group(project_code="demo") mock_tg_cls.get.return_value = tg mock_pr_cls.get_for_column.return_value = _mock_profile_result( - pii_flag="B/CONTACT/Email", top_freq_values="| alice@example.com | 5\n| bob@example.com | 3", ) mock_run_cls.get.return_value = _mock_profiling_run_for_tg(tg.id) + # The pii_flag the tool reads comes from DataColumnChars, not ProfileResult. + mock_dcc_select.return_value = [_mock_data_column(pii_flag="B/CONTACT/Email")] # Default test conftest grants no view_pii (TEST_PERM_MATRIX has no entry). from testgen.mcp.tools.profiling import get_column_frequent_values @@ -1604,20 +1650,21 @@ def test_get_column_frequent_values_pii_value_redacted_when_caller_lacks_view_pi assert "alice@example.com" not in result +@patch.object(DataColumnChars, "select_where") @patch("testgen.mcp.permissions._compute_project_permissions") @patch("testgen.mcp.tools.profiling.ProfilingRun") @patch("testgen.mcp.tools.profiling.ProfileResult") @patch("testgen.mcp.tools.common.TableGroup") def test_get_column_frequent_values_pii_value_visible_with_view_pii_grant( - mock_tg_cls, mock_pr_cls, mock_run_cls, mock_compute, db_session_mock, + mock_tg_cls, mock_pr_cls, mock_run_cls, mock_compute, mock_dcc_select, db_session_mock, ): tg = _mock_table_group(project_code="demo") mock_tg_cls.get.return_value = tg mock_pr_cls.get_for_column.return_value = _mock_profile_result( - pii_flag="B/CONTACT/Email", top_freq_values="| alice@example.com | 5\n| bob@example.com | 3", ) mock_run_cls.get.return_value = _mock_profiling_run_for_tg(tg.id) + mock_dcc_select.return_value = [_mock_data_column(pii_flag="B/CONTACT/Email")] mock_compute.return_value = ProjectPermissions( memberships={"demo": "role_a"}, permission="catalog", @@ -1633,11 +1680,12 @@ def test_get_column_frequent_values_pii_value_visible_with_view_pii_grant( assert PII_REDACTED not in result +@patch.object(DataColumnChars, "select_where") @patch("testgen.mcp.tools.profiling.ProfilingRun") @patch("testgen.mcp.tools.profiling.ProfileResult") @patch("testgen.mcp.tools.common.TableGroup") def test_get_column_frequent_values_high_cardinality_fallback( - mock_tg_cls, mock_pr_cls, mock_run_cls, db_session_mock, + mock_tg_cls, mock_pr_cls, mock_run_cls, mock_dcc_select, db_session_mock, ): tg = _mock_table_group() mock_tg_cls.get.return_value = tg @@ -1645,6 +1693,7 @@ def test_get_column_frequent_values_high_cardinality_fallback( top_freq_values=None, distinct_value_ct=10000, ) mock_run_cls.get.return_value = _mock_profiling_run_for_tg(tg.id) + mock_dcc_select.return_value = [_mock_data_column()] from testgen.mcp.tools.profiling import get_column_frequent_values result = get_column_frequent_values(str(uuid4()), "customers", "customer_id") @@ -1666,15 +1715,45 @@ def test_get_column_frequent_values_missing_profile_raises_not_accessible( get_column_frequent_values(str(uuid4()), "customers", "ghost") +@patch.object(DataColumnChars, "select_where") +@patch("testgen.mcp.tools.profiling.ProfilingRun") +@patch("testgen.mcp.tools.profiling.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_column_frequent_values_pii_source_is_data_column_chars_not_profile_result( + mock_tg_cls, mock_pr_cls, mock_run_cls, mock_dcc_select, db_session_mock, +): + """``data_column_chars.pii_flag`` is the source of truth; ``profile_result.pii_flag`` is ignored.""" + tg = _mock_table_group(project_code="demo") + mock_tg_cls.get.return_value = tg + # ProfileResult carries a stale/wrong pii_flag; DataColumnChars says None. + mock_pr_cls.get_for_column.return_value = _mock_profile_result( + pii_flag="A/CONTACT/Email", # stale value; should NOT drive redaction + top_freq_values="| alice@example.com | 5", + ) + mock_run_cls.get.return_value = _mock_profiling_run_for_tg(tg.id) + mock_dcc_select.return_value = [_mock_data_column(pii_flag=None)] + + from testgen.mcp.tools.profiling import get_column_frequent_values + result = get_column_frequent_values(str(uuid4()), "customers", "email") + + # No redaction, no PII field — because DataColumnChars says the column is not PII. + assert PII_REDACTED not in result + assert "alice@example.com" in result + assert "PII" not in result.splitlines()[1:6] # no "PII:" field in the header block + + # ---------------------------------------------------------------------- # get_column_patterns # ---------------------------------------------------------------------- +@patch.object(DataColumnChars, "select_where") @patch("testgen.mcp.tools.profiling.ProfilingRun") @patch("testgen.mcp.tools.profiling.ProfileResult") @patch("testgen.mcp.tools.common.TableGroup") -def test_get_column_patterns_happy_path(mock_tg_cls, mock_pr_cls, mock_run_cls, db_session_mock): +def test_get_column_patterns_happy_path( + mock_tg_cls, mock_pr_cls, mock_run_cls, mock_dcc_select, db_session_mock, +): tg = _mock_table_group() mock_tg_cls.get.return_value = tg mock_pr_cls.get_for_column.return_value = _mock_profile_result( @@ -1682,6 +1761,7 @@ def test_get_column_patterns_happy_path(mock_tg_cls, mock_pr_cls, mock_run_cls, top_patterns="326 | Aaaaaa | 176 | AAA", ) mock_run_cls.get.return_value = _mock_profiling_run_for_tg(tg.id) + mock_dcc_select.return_value = [_mock_data_column()] from testgen.mcp.tools.profiling import get_column_patterns result = get_column_patterns(str(uuid4()), "customers", "country") @@ -1691,11 +1771,12 @@ def test_get_column_patterns_happy_path(mock_tg_cls, mock_pr_cls, mock_run_cls, assert "Top patterns" in result +@patch.object(DataColumnChars, "select_where") @patch("testgen.mcp.tools.profiling.ProfilingRun") @patch("testgen.mcp.tools.profiling.ProfileResult") @patch("testgen.mcp.tools.common.TableGroup") def test_get_column_patterns_non_string_column_fallback( - mock_tg_cls, mock_pr_cls, mock_run_cls, db_session_mock, + mock_tg_cls, mock_pr_cls, mock_run_cls, mock_dcc_select, db_session_mock, ): tg = _mock_table_group() mock_tg_cls.get.return_value = tg @@ -1704,6 +1785,7 @@ def test_get_column_patterns_non_string_column_fallback( top_patterns=None, ) mock_run_cls.get.return_value = _mock_profiling_run_for_tg(tg.id) + mock_dcc_select.return_value = [_mock_data_column()] from testgen.mcp.tools.profiling import get_column_patterns result = get_column_patterns(str(uuid4()), "products", "price") @@ -1711,11 +1793,12 @@ def test_get_column_patterns_non_string_column_fallback( assert "column is not a string type" in result +@patch.object(DataColumnChars, "select_where") @patch("testgen.mcp.tools.profiling.ProfilingRun") @patch("testgen.mcp.tools.profiling.ProfileResult") @patch("testgen.mcp.tools.common.TableGroup") def test_get_column_patterns_high_cardinality_fallback( - mock_tg_cls, mock_pr_cls, mock_run_cls, db_session_mock, + mock_tg_cls, mock_pr_cls, mock_run_cls, mock_dcc_select, db_session_mock, ): tg = _mock_table_group() mock_tg_cls.get.return_value = tg @@ -1725,6 +1808,7 @@ def test_get_column_patterns_high_cardinality_fallback( distinct_value_ct=9999, ) mock_run_cls.get.return_value = _mock_profiling_run_for_tg(tg.id) + mock_dcc_select.return_value = [_mock_data_column()] from testgen.mcp.tools.profiling import get_column_patterns result = get_column_patterns(str(uuid4()), "customers", "address") From 7eb4c2c33bf76811f1dc48e2f3533412e20e6295 Mon Sep 17 00:00:00 2001 From: Ricardo Boni Date: Thu, 14 May 2026 10:07:56 -0400 Subject: [PATCH 24/58] =?UTF-8?q?feat(mcp):=20profiling=20L4=20=E2=80=94?= =?UTF-8?q?=20cross-run=20comparison,=20trends,=20schema=20history=20(TG-1?= =?UTF-8?q?068)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three MCP tools that operate across profiling runs of the same table group. Tools: - compare_profiling_runs(target_*, baseline_*?, table?, column?) — diff two completed runs. When `baseline_*` is omitted, defaults to the previous completed run on the same table group. Renders a two-row header table (Job ID / Started for Target vs Baseline) followed by per-general-type metric tables (Numeric / Alpha / Date / Boolean) for shared columns, categorical changes as bullets, and hygiene-issue churn (Added / Resolved). A one-line note flags structural drift with a pointer; the per-table/column structural diff is delegated to get_schema_history. - get_profiling_trends(table_group_id, metrics, table?, column?, limit?) — caller-named metric time-series across the last N completed runs. Sparse rendering for runs where the entity didn't exist; first-appears / last-appears notes name the bounds of the entity's presence in the window. Metric vocabulary is a new ProfileMetric StrEnum. - get_schema_history(table_group_id, limit?) — per-run timeline of tables/columns added or dropped, type changes, and record-count deltas per surviving table. Helpers: - ProfileResult.select_for_runs(run_ids, table?, column?) — single query loading rows for multiple runs. - ProfilingRun.list_recent_complete(tg_id, limit) — newest-first completed runs for a table group. - ProfilingRun.count_confirmed_hygiene_issues(run_ids) — confirmed hygiene issue counts per run. - parse_profile_metrics() in mcp/tools/common.py — validate a list of caller-supplied metric names against the ProfileMetric vocabulary. Includes 38 new unit tests. Co-Authored-By: Claude Opus 4.7 (1M context) --- testgen/common/models/data_column.py | 33 + testgen/common/models/profile_result.py | 23 + testgen/common/models/profiling_run.py | 32 + testgen/mcp/server.py | 8 + testgen/mcp/tools/common.py | 18 + testgen/mcp/tools/profile_history.py | 866 +++++++++++++++++++ tests/unit/mcp/test_tools_profile_history.py | 570 ++++++++++++ 7 files changed, 1550 insertions(+) create mode 100644 testgen/mcp/tools/profile_history.py create mode 100644 tests/unit/mcp/test_tools_profile_history.py diff --git a/testgen/common/models/data_column.py b/testgen/common/models/data_column.py index 4027c2ca..e0e5d8ec 100644 --- a/testgen/common/models/data_column.py +++ b/testgen/common/models/data_column.py @@ -91,6 +91,39 @@ class ColumnOrderBy(StrEnum): HYGIENE_COUNT = "Hygiene Count" +class ProfileMetric(StrEnum): + """Profile-metric vocabulary: linear/arithmetic stats from a profiling run. + + Covers general column ratios (null / distinct / filled), type-specific + statistics (length, numeric range, date range, true count), table-level + record count, and table-group rollups (profiling score, hygiene count). + """ + + # Apply to any column + NULL_RATIO = "Null Ratio" + DISTINCT_RATIO = "Distinct Ratio" + FILLED_RATIO = "Filled Ratio" + # Apply to the parent table + RECORD_COUNT = "Record Count" + # Apply to the whole table group + PROFILING_SCORE = "Profiling Score" + HYGIENE_COUNT = "Hygiene Count" + # Alpha-only + MIN_LENGTH = "Min Length" + MAX_LENGTH = "Max Length" + AVG_LENGTH = "Avg Length" + # Numeric-only + MIN = "Min" + MAX = "Max" + AVG = "Avg" + STDEV = "Stdev" + # Date-only + MIN_DATE = "Min Date" + MAX_DATE = "Max Date" + # Boolean-only + TRUE_COUNT = "True Count" + + @dataclass class ColumnProfileSummary(EntityMinimal): column_name: str diff --git a/testgen/common/models/profile_result.py b/testgen/common/models/profile_result.py index 0eef47d6..046cb015 100644 --- a/testgen/common/models/profile_result.py +++ b/testgen/common/models/profile_result.py @@ -1,3 +1,4 @@ +from collections.abc import Iterable from datetime import datetime from uuid import UUID, uuid4 @@ -124,3 +125,25 @@ def get_for_column( rows = list(cls.select_where(*clauses, order_by=(desc(cls.profile_run_id),))) return rows[0] if rows else None + + @classmethod + def select_for_runs( + cls, + run_ids: Iterable[UUID], + table_name: str | None = None, + column_name: str | None = None, + ) -> list["ProfileResult"]: + """Fetch profile-results rows for a set of profiling runs in one query. + + Optional ``table_name`` and ``column_name`` filters narrow the result to one + entity (case-sensitive exact match). + """ + run_ids = list(run_ids) + if not run_ids: + return [] + clauses = [cls.profile_run_id.in_(run_ids)] + if table_name is not None: + clauses.append(cls.table_name == table_name) + if column_name is not None: + clauses.append(cls.column_name == column_name) + return list(cls.select_where(*clauses)) diff --git a/testgen/common/models/profiling_run.py b/testgen/common/models/profiling_run.py index 2dc05f5d..c37bd407 100644 --- a/testgen/common/models/profiling_run.py +++ b/testgen/common/models/profiling_run.py @@ -423,3 +423,35 @@ def get_previous(self) -> Self | None: .limit(1) ) return get_current_session().scalar(query) + + @classmethod + def list_recent_complete(cls, table_groups_id: UUID, limit: int) -> list[Self]: + """Return the most recent completed profiling runs for a table group, newest first.""" + query = ( + select(cls) + .join(JobExecution, cls.job_execution_id == JobExecution.id) + .where( + cls.table_groups_id == table_groups_id, + JobExecution.status == JobStatus.COMPLETED, + ) + .order_by(desc(JobExecution.started_at)) + .limit(limit) + ) + return list(get_current_session().scalars(query)) + + @classmethod + def count_confirmed_hygiene_issues(cls, run_ids: list[UUID]) -> dict[UUID, int]: + """Count confirmed hygiene issues per profiling run. Missing runs default to zero.""" + if not run_ids: + return {} + from testgen.common.models.hygiene_issue import HygieneIssue + + query = ( + select(HygieneIssue.profile_run_id, func.count()) + .where( + HygieneIssue.profile_run_id.in_(run_ids), + func.coalesce(HygieneIssue.disposition, "Confirmed") == "Confirmed", + ) + .group_by(HygieneIssue.profile_run_id) + ) + return {row[0]: row[1] for row in get_current_session().execute(query)} diff --git a/testgen/mcp/server.py b/testgen/mcp/server.py index 77fa0bc0..7b573679 100644 --- a/testgen/mcp/server.py +++ b/testgen/mcp/server.py @@ -151,6 +151,11 @@ def build_mcp_server( search_hygiene_issues, update_hygiene_issue, ) + from testgen.mcp.tools.profile_history import ( + compare_profiling_runs, + get_profiling_trends, + get_schema_history, + ) from testgen.mcp.tools.profiling import ( get_column_frequent_values, get_column_patterns, @@ -243,6 +248,9 @@ def safe_prompt(fn): safe_tool(get_column_frequent_values) safe_tool(get_column_patterns) safe_tool(search_columns) + safe_tool(compare_profiling_runs) + safe_tool(get_profiling_trends) + safe_tool(get_schema_history) safe_tool(run_tests) safe_tool(run_profiling) safe_tool(cancel_test_run) diff --git a/testgen/mcp/tools/common.py b/testgen/mcp/tools/common.py index 054c508f..6017e449 100644 --- a/testgen/mcp/tools/common.py +++ b/testgen/mcp/tools/common.py @@ -11,6 +11,7 @@ GENERAL_TYPE_TO_CODE, ColumnOrderBy, GeneralType, + ProfileMetric, SuggestedDataType, ) from testgen.common.models.hygiene_issue import HygieneIssueType @@ -245,6 +246,23 @@ def parse_column_order_by(value: str) -> ColumnOrderBy: raise MCPUserError(f"Invalid order_by `{value}`. Valid values: {valid}") from err +def parse_profile_metrics(values: list[str]) -> list[ProfileMetric]: + """Validate a list of profile metric names. Empties out with one error listing all invalids.""" + if not values: + raise MCPUserError("`metrics` cannot be empty — name at least one metric to trend.") + parsed: list[ProfileMetric] = [] + invalid: list[str] = [] + for value in values: + try: + parsed.append(ProfileMetric(value)) + except ValueError: + invalid.append(value) + if invalid: + valid = ", ".join(m.value for m in ProfileMetric) + raise MCPUserError(f"Invalid metrics {invalid}. Valid values: {valid}") + return parsed + + # ``pii_flag`` encodes risk as a single-character prefix: ``A`` (High), ``B`` (Moderate), ``C`` (Low). _PII_RISK_LEVEL_TO_CODE: dict[str, str] = {"High": "A", "Moderate": "B", "Low": "C"} diff --git a/testgen/mcp/tools/profile_history.py b/testgen/mcp/tools/profile_history.py new file mode 100644 index 00000000..a4eb7897 --- /dev/null +++ b/testgen/mcp/tools/profile_history.py @@ -0,0 +1,866 @@ +"""MCP tools that operate across multiple profiling runs of a table group. + +- ``compare_profiling_runs`` — diff two runs (metric changes for shared columns + hygiene churn). +- ``get_profiling_trends`` — caller-named metric time-series across recent runs. +- ``get_schema_history`` — per-run structural changes (tables/columns added/dropped/re-typed) + with table record-count deltas. + +Structural enumeration intentionally lives only in ``get_schema_history``; the comparison tool +renders a one-line pointer to it rather than duplicating the per-table churn. +""" +from collections import defaultdict +from collections.abc import Iterable +from datetime import datetime +from typing import NamedTuple +from uuid import UUID + +from sqlalchemy import func + +from testgen.common.models import with_database_session +from testgen.common.models.data_column import ProfileMetric +from testgen.common.models.hygiene_issue import HygieneIssue, HygieneIssueType +from testgen.common.models.profile_result import ProfileResult +from testgen.common.models.profiling_run import ProfilingRun +from testgen.mcp.exceptions import MCPUserError +from testgen.mcp.permissions import mcp_permission +from testgen.mcp.tools.common import ( + DocGroup, + parse_profile_metrics, + resolve_profiling_run, + resolve_table_group, + validate_limit, +) +from testgen.mcp.tools.markdown import MdDoc +from testgen.utils import friendly_score + +_DOC_GROUP = DocGroup.BROWSE_PROFILING + + +# --------------------------------------------------------------------------- +# General-type vocabulary +# --------------------------------------------------------------------------- + +# Single-letter general_type codes (stored on ProfileResult.general_type and +# DataColumnChars.general_type). Mirrors GENERAL_TYPE_TO_CODE values but locally +# named for readability inside this module's scope/type-restriction tables. +_TYPE_ALPHA = "A" +_TYPE_NUMERIC = "N" +_TYPE_DATE = "D" +_TYPE_BOOLEAN = "B" + +_TYPE_LABELS: dict[str, str] = { + _TYPE_ALPHA: "Alpha", + _TYPE_NUMERIC: "Numeric", + _TYPE_DATE: "Date", + _TYPE_BOOLEAN: "Boolean", + "T": "Time", + "X": "Other", +} + + +# --------------------------------------------------------------------------- +# Metric scope + extraction +# --------------------------------------------------------------------------- + +_SCOPE_TABLE_GROUP = "table_group" +_SCOPE_TABLE = "table" +_SCOPE_COLUMN = "column" + +_METRIC_SCOPE: dict[ProfileMetric, str] = { + ProfileMetric.NULL_RATIO: _SCOPE_COLUMN, + ProfileMetric.DISTINCT_RATIO: _SCOPE_COLUMN, + ProfileMetric.FILLED_RATIO: _SCOPE_COLUMN, + ProfileMetric.MIN_LENGTH: _SCOPE_COLUMN, + ProfileMetric.MAX_LENGTH: _SCOPE_COLUMN, + ProfileMetric.AVG_LENGTH: _SCOPE_COLUMN, + ProfileMetric.MIN: _SCOPE_COLUMN, + ProfileMetric.MAX: _SCOPE_COLUMN, + ProfileMetric.AVG: _SCOPE_COLUMN, + ProfileMetric.STDEV: _SCOPE_COLUMN, + ProfileMetric.MIN_DATE: _SCOPE_COLUMN, + ProfileMetric.MAX_DATE: _SCOPE_COLUMN, + ProfileMetric.TRUE_COUNT: _SCOPE_COLUMN, + ProfileMetric.RECORD_COUNT: _SCOPE_TABLE, + ProfileMetric.PROFILING_SCORE: _SCOPE_TABLE_GROUP, + ProfileMetric.HYGIENE_COUNT: _SCOPE_TABLE_GROUP, +} + +# Type-specific metrics only return a value when the column's general_type matches. +_METRIC_TYPE: dict[ProfileMetric, str] = { + ProfileMetric.MIN_LENGTH: _TYPE_ALPHA, + ProfileMetric.MAX_LENGTH: _TYPE_ALPHA, + ProfileMetric.AVG_LENGTH: _TYPE_ALPHA, + ProfileMetric.MIN: _TYPE_NUMERIC, + ProfileMetric.MAX: _TYPE_NUMERIC, + ProfileMetric.AVG: _TYPE_NUMERIC, + ProfileMetric.STDEV: _TYPE_NUMERIC, + ProfileMetric.MIN_DATE: _TYPE_DATE, + ProfileMetric.MAX_DATE: _TYPE_DATE, + ProfileMetric.TRUE_COUNT: _TYPE_BOOLEAN, +} + +# Metrics rendered as percentages. +_PERCENT_METRICS = { + ProfileMetric.NULL_RATIO, + ProfileMetric.DISTINCT_RATIO, + ProfileMetric.FILLED_RATIO, +} + + +def _validate_metric_scope(metrics: list[ProfileMetric], table_name: str | None, column_name: str | None) -> None: + """Reject when any metric needs a deeper scope than the provided arguments offer.""" + needs_column = [m for m in metrics if _METRIC_SCOPE[m] == _SCOPE_COLUMN] + needs_table = [m for m in metrics if _METRIC_SCOPE[m] == _SCOPE_TABLE] + if needs_column and column_name is None: + names = ", ".join(f"`{m.value}`" for m in needs_column) + raise MCPUserError(f"Metrics {names} require both `table_name` and `column_name`.") + if needs_table and table_name is None: + names = ", ".join(f"`{m.value}`" for m in needs_table) + raise MCPUserError(f"Metrics {names} require `table_name`.") + + +def _column_metric_value(metric: ProfileMetric, pr: ProfileResult | None) -> object | None: + """Extract a column-scope metric value from a ProfileResult row. + + Returns ``None`` if the row is missing or the metric doesn't apply to the + column's ``general_type`` (e.g. ``Avg Length`` on a numeric column). + """ + if pr is None: + return None + required_type = _METRIC_TYPE.get(metric) + if required_type is not None and pr.general_type != required_type: + return None + record_ct = pr.record_ct + if metric is ProfileMetric.NULL_RATIO: + return pr.null_value_ct / record_ct if record_ct and pr.null_value_ct is not None else None + if metric is ProfileMetric.DISTINCT_RATIO: + return pr.distinct_value_ct / record_ct if record_ct and pr.distinct_value_ct is not None else None + if metric is ProfileMetric.FILLED_RATIO: + return pr.filled_value_ct / record_ct if record_ct and pr.filled_value_ct is not None else None + if metric is ProfileMetric.RECORD_COUNT: + return pr.record_ct + if metric is ProfileMetric.MIN_LENGTH: + return pr.min_length + if metric is ProfileMetric.MAX_LENGTH: + return pr.max_length + if metric is ProfileMetric.AVG_LENGTH: + return pr.avg_length + if metric is ProfileMetric.MIN: + return pr.min_value + if metric is ProfileMetric.MAX: + return pr.max_value + if metric is ProfileMetric.AVG: + return pr.avg_value + if metric is ProfileMetric.STDEV: + return pr.stdev_value + if metric is ProfileMetric.MIN_DATE: + return pr.min_date + if metric is ProfileMetric.MAX_DATE: + return pr.max_date + if metric is ProfileMetric.TRUE_COUNT: + return pr.boolean_true_ct + return None + + +def _format_metric_value(metric: ProfileMetric, value: object | None) -> str: + if value is None: + return "—" + if metric is ProfileMetric.PROFILING_SCORE and isinstance(value, int | float): + return friendly_score(value) or "—" + if metric in _PERCENT_METRICS and isinstance(value, int | float): + return f"{float(value) * 100:.1f}%" + if isinstance(value, datetime): + return value.date().isoformat() + if isinstance(value, float): + # 6 significant digits with thousands separators preserves precision for + # ratios in the 0.x range (e.g. 5.94821) while keeping wide values readable + # (e.g. 12,345.6). + return f"{value:,.6g}" + if isinstance(value, int): + return f"{value:,}" + return str(value) + + +def _delta_cell(metric: ProfileMetric, baseline: object | None, target: object | None) -> str: + """Render a baseline → target cell. ``B (=)`` when unchanged after formatting. + + Equality is checked on the formatted strings, not the raw values — two timestamps + that render as the same date display as ``(=)`` rather than a no-op ``→``. + """ + baseline_str = _format_metric_value(metric, baseline) + target_str = _format_metric_value(metric, target) + if baseline_str == target_str: + return f"{target_str} (=)" + return f"{baseline_str} → {target_str}" + + +# --------------------------------------------------------------------------- +# Run-state guard +# --------------------------------------------------------------------------- + +_REQUIRED_RUN_STATUS = "Complete" + + +def _require_completed(run: ProfilingRun, label: str) -> None: + if run.status != _REQUIRED_RUN_STATUS: + raise MCPUserError( + f"{label} run is in `{run.status}` state — comparison requires a completed run." + ) + + +# --------------------------------------------------------------------------- +# Compare profiling runs +# --------------------------------------------------------------------------- + + +# Per-general-type metric tables. Excludes the type-display column header so the +# table is uniformly wide; cross-flavor type-display drift is surfaced via footnote. +_METRIC_TABLE_BY_TYPE: dict[str, list[ProfileMetric]] = { + _TYPE_NUMERIC: [ + ProfileMetric.NULL_RATIO, + ProfileMetric.DISTINCT_RATIO, + ProfileMetric.MIN, + ProfileMetric.MAX, + ProfileMetric.AVG, + ProfileMetric.STDEV, + ProfileMetric.RECORD_COUNT, + ], + _TYPE_ALPHA: [ + ProfileMetric.NULL_RATIO, + ProfileMetric.DISTINCT_RATIO, + ProfileMetric.AVG_LENGTH, + ProfileMetric.MIN_LENGTH, + ProfileMetric.MAX_LENGTH, + ProfileMetric.RECORD_COUNT, + ], + _TYPE_DATE: [ + ProfileMetric.NULL_RATIO, + ProfileMetric.MIN_DATE, + ProfileMetric.MAX_DATE, + ProfileMetric.RECORD_COUNT, + ], + _TYPE_BOOLEAN: [ + ProfileMetric.NULL_RATIO, + ProfileMetric.TRUE_COUNT, + ProfileMetric.RECORD_COUNT, + ], +} + +# Categorical attributes rendered only when they change. Keys are user-facing +# field labels; values are ProfileResult attribute names. +_CATEGORICAL_FIELDS: dict[str, str] = { + "Type": "column_type", + "Semantic Type": "functional_data_type", + "PII": "pii_flag", + "Suggested Type": "datatype_suggestion", +} + + +def _pair_results( + rows: Iterable[ProfileResult], target_run_id: UUID, baseline_run_id: UUID, +) -> dict[tuple[str, str, str], dict[str, ProfileResult]]: + """Group profile-results by (schema, table, column) and tag each row as target/baseline.""" + by_key: dict[tuple[str, str, str], dict[str, ProfileResult]] = defaultdict(dict) + for row in rows: + key = (row.schema_name, row.table_name, row.column_name) + if row.profile_run_id == target_run_id: + by_key[key]["target"] = row + elif row.profile_run_id == baseline_run_id: + by_key[key]["baseline"] = row + return by_key + + +@with_database_session +@mcp_permission("catalog") +def compare_profiling_runs( + target_job_execution_id: str, + baseline_job_execution_id: str | None = None, + table_name: str | None = None, + column_name: str | None = None, +) -> str: + """Compare two profiling runs on the same table group and report metric changes for shared columns plus hygiene issue churn. + + When ``baseline_job_execution_id`` is omitted, the baseline defaults to the most recent + completed profiling run on the same table group submitted before the target run. Both + runs must be in `Completed` state. + + Reports only on columns present in both runs. When structural drift exists, the output + notes that fact in one line; the per-table/column structural diff is not enumerated here. + + Args: + target_job_execution_id: UUID of the newer profiling run (the "after" snapshot), + e.g. from `list_profiling_runs`. + baseline_job_execution_id: Optional UUID of the older profiling run (the "before" + snapshot). When omitted, defaults to the previous completed run on the same + table group. + table_name: Optional — restrict the comparison to one table (case-sensitive). + column_name: Optional — restrict the comparison to one column (requires + `table_name` when used in the diff render but accepted independently). + """ + target_run = resolve_profiling_run(target_job_execution_id) + _require_completed(target_run, "Target") + + if baseline_job_execution_id is None: + baseline_run = target_run.get_previous() + if baseline_run is None: + raise MCPUserError( + f"Target run `{target_job_execution_id}` is the first completed profiling run " + "on its table group — pass `baseline_job_execution_id` to compare against." + ) + else: + baseline_run = resolve_profiling_run(baseline_job_execution_id) + _require_completed(baseline_run, "Baseline") + if baseline_run.table_groups_id != target_run.table_groups_id: + raise MCPUserError( + "Both runs must belong to the same table group to be comparable. " + f"Target is in table group `{target_run.table_groups_id}`, " + f"baseline is in table group `{baseline_run.table_groups_id}`." + ) + + rows = ProfileResult.select_for_runs( + run_ids=[target_run.id, baseline_run.id], + table_name=table_name, + column_name=column_name, + ) + paired = _pair_results(rows, target_run.id, baseline_run.id) + + has_structural_changes = any( + "target" not in sides or "baseline" not in sides for sides in paired.values() + ) + shared = {key: sides for key, sides in paired.items() if "target" in sides and "baseline" in sides} + + hygiene_diff = _diff_hygiene_issues( + target_run.id, baseline_run.id, table_name=table_name, column_name=column_name, + ) + + return _render_run_comparison( + target_run=target_run, + baseline_run=baseline_run, + shared=shared, + has_structural_changes=has_structural_changes, + hygiene_diff=hygiene_diff, + ) + + +class _HygieneRow(NamedTuple): + table_name: str + column_name: str + issue_type: str + + +def _diff_hygiene_issues( + target_run_id: UUID, + baseline_run_id: UUID, + table_name: str | None, + column_name: str | None, +) -> dict[str, list[_HygieneRow]]: + """Return ``{"added": [...], "resolved": [...]}`` lists of hygiene-issue rows. + + Matches issues across the two runs by (table, column, type_id) — only confirmed + issues (default disposition) are counted. + """ + clauses = [ + HygieneIssue.profile_run_id.in_([target_run_id, baseline_run_id]), + func.coalesce(HygieneIssue.disposition, "Confirmed") == "Confirmed", + ] + if table_name is not None: + clauses.append(HygieneIssue.table_name == table_name) + if column_name is not None: + clauses.append(HygieneIssue.column_name == column_name) + issues = list(HygieneIssue.select_where(*clauses)) + + type_ids = {issue.type_id for issue in issues} + type_names: dict[str, str] = {} + if type_ids: + type_names = { + t.id: t.name for t in HygieneIssueType.select_where(HygieneIssueType.id.in_(type_ids)) + } + + target_keys: set[tuple[str, str, str]] = set() + baseline_keys: set[tuple[str, str, str]] = set() + for issue in issues: + key = (issue.table_name, issue.column_name, issue.type_id) + if issue.profile_run_id == target_run_id: + target_keys.add(key) + else: + baseline_keys.add(key) + + def _rows(keys: Iterable[tuple[str, str, str]]) -> list[_HygieneRow]: + return sorted( + (_HygieneRow(t, c, type_names.get(tid, tid)) for t, c, tid in keys), + key=lambda r: (r.table_name, r.column_name, r.issue_type), + ) + + return { + "added": _rows(target_keys - baseline_keys), + "resolved": _rows(baseline_keys - target_keys), + } + + +def _categorical_change(label: str, baseline: ProfileResult, target: ProfileResult) -> tuple[str, str] | None: + """Return ``(label, "B → T")`` when a categorical field changed, else ``None``.""" + attr = _CATEGORICAL_FIELDS[label] + baseline_value = getattr(baseline, attr) + target_value = getattr(target, attr) + if baseline_value == target_value: + return None + baseline_display = baseline_value if baseline_value is not None else "—" + target_display = target_value if target_value is not None else "—" + return label, f"{baseline_display} → {target_display}" + + +def _build_metric_rows_for_type( + general_type: str, + shared: dict[tuple[str, str, str], dict[str, ProfileResult]], +) -> tuple[list[str], list[list[str]]]: + """Build (headers, rows) for the metric-change table for one general_type bucket.""" + metrics = _METRIC_TABLE_BY_TYPE[general_type] + headers = ["Table", "Column", *(m.value for m in metrics)] + rows: list[list[str]] = [] + for (_, table, column), sides in sorted(shared.items()): + baseline = sides["baseline"] + target = sides["target"] + # Bucket by target's type. Columns that switched type between runs render here + # under the new type; the old/new type is also surfaced as a categorical change. + if target.general_type != general_type: + continue + # Only render columns that changed in at least one metric in this bucket. + deltas: list[str] = [] + any_changed = False + for metric in metrics: + target_value = _column_metric_value(metric, target) + baseline_value = _column_metric_value(metric, baseline) + if target_value != baseline_value: + any_changed = True + deltas.append(_delta_cell(metric, baseline_value, target_value)) + if any_changed: + rows.append([table, column, *deltas]) + return headers, rows + + +def _categorical_lines( + shared: dict[tuple[str, str, str], dict[str, ProfileResult]], +) -> list[str]: + """Return one bullet per shared column that has at least one categorical change.""" + lines: list[str] = [] + for (_, table, column), sides in sorted(shared.items()): + baseline = sides["baseline"] + target = sides["target"] + changes: list[str] = [] + for label in _CATEGORICAL_FIELDS: + change = _categorical_change(label, baseline, target) + if change is not None: + changes.append(f"{change[0]}: {change[1]}") + if changes: + lines.append(f"`{table}.{column}` — {', '.join(changes)}") + return lines + + +def _render_run_comparison( + target_run: ProfilingRun, + baseline_run: ProfilingRun, + shared: dict[tuple[str, str, str], dict[str, ProfileResult]], + has_structural_changes: bool, + hygiene_diff: dict[str, list[_HygieneRow]], +) -> str: + doc = MdDoc() + doc.heading(1, "Profiling Run Comparison") + doc.table( + ["", "Target", "Baseline"], + [ + ["Profiling Run", + MdDoc.code(str(target_run.job_execution_id)), + MdDoc.code(str(baseline_run.job_execution_id))], + ["Started", target_run.profiling_starttime, baseline_run.profiling_starttime], + ], + ) + + if has_structural_changes: + doc.text( + "_Structural changes also occurred between these runs — " + "call `get_schema_history(table_group_id)` for the per-table/column diff._" + ) + + # Metric tables, one per general_type bucket + rendered_any_metric_table = False + for general_type in (_TYPE_NUMERIC, _TYPE_ALPHA, _TYPE_DATE, _TYPE_BOOLEAN): + headers, rows = _build_metric_rows_for_type(general_type, shared) + if rows: + rendered_any_metric_table = True + doc.heading(2, f"{_TYPE_LABELS[general_type]} columns") + doc.table(headers, rows, code=[0, 1]) + + categorical_lines = _categorical_lines(shared) + if categorical_lines: + doc.heading(2, "Categorical changes") + doc.bullets(categorical_lines) + + added = hygiene_diff["added"] + resolved = hygiene_diff["resolved"] + if added or resolved: + doc.heading(2, "Hygiene issues") + if resolved: + doc.heading(3, f"Resolved ({len(resolved)})") + doc.table( + ["Table", "Column", "Issue type"], + [[r.table_name, r.column_name, r.issue_type] for r in resolved], + code=[0, 1], + ) + if added: + doc.heading(3, f"Added ({len(added)})") + doc.table( + ["Table", "Column", "Issue type"], + [[r.table_name, r.column_name, r.issue_type] for r in added], + code=[0, 1], + ) + + if not (rendered_any_metric_table or categorical_lines or added or resolved or has_structural_changes): + doc.text("_No changes between target and baseline._") + + return doc.render() + + +# --------------------------------------------------------------------------- +# Profiling trends +# --------------------------------------------------------------------------- + + +@with_database_session +@mcp_permission("catalog") +def get_profiling_trends( + table_group_id: str, + metrics: list[str], + table_name: str | None = None, + column_name: str | None = None, + limit: int = 10, +) -> str: + """Show a time series of caller-named profiling metrics across recent completed runs of a table group. + + Metric scope rules: + - Column-level metrics (e.g. `Null Ratio`, `Avg Length`, `Min`) require both + `table_name` and `column_name`. + - `Record Count` is table-level and requires `table_name`. + - `Profiling Score` and `Hygiene Count` are table-group-level and accept any scope. + - Type-specific metrics return `—` for runs where the column's general type + didn't match (e.g. `Min` on a column that was Alpha in an earlier run). + + Args: + table_group_id: UUID of the table group, e.g. from `get_data_inventory`. + metrics: One or more metric names. Accepted values: `Null Ratio`, `Distinct Ratio`, + `Filled Ratio`, `Record Count`, `Profiling Score`, `Hygiene Count`, + `Min Length`, `Max Length`, `Avg Length`, `Min`, `Max`, `Avg`, `Stdev`, + `Min Date`, `Max Date`, `True Count`. + table_name: Optional — restrict to one table (case-sensitive). + column_name: Optional — restrict to one column (case-sensitive); requires + `table_name`. + limit: Number of most-recent completed runs to include (default 10, max 50). + """ + validate_limit(limit, 50) + if column_name is not None and table_name is None: + raise MCPUserError("`column_name` requires `table_name`.") + + tg = resolve_table_group(table_group_id) + metric_enums = parse_profile_metrics(metrics) + _validate_metric_scope(metric_enums, table_name, column_name) + + runs = ProfilingRun.list_recent_complete(tg.id, limit=limit) + if not runs: + return f"No completed profiling runs found for table group `{table_group_id}`." + + run_ids = [r.id for r in runs] + needs_profile_rows = any(_METRIC_SCOPE[m] in (_SCOPE_COLUMN, _SCOPE_TABLE) for m in metric_enums) + profile_by_run: dict[UUID, ProfileResult] = {} + if needs_profile_rows: + rows = ProfileResult.select_for_runs( + run_ids=run_ids, table_name=table_name, column_name=column_name, + ) + if column_name is not None: + profile_by_run = {row.profile_run_id: row for row in rows} + else: + # Table-only scope: there may be many ProfileResult rows per (run, table). + # All carry the same record_ct (table-level); take any. + for row in rows: + profile_by_run.setdefault(row.profile_run_id, row) + + hygiene_counts: dict[UUID, int] = {} + if ProfileMetric.HYGIENE_COUNT in metric_enums: + hygiene_counts = ProfilingRun.count_confirmed_hygiene_issues(run_ids) + + # Bound the entity's presence in the window. `first_seen_run` is the oldest run + # with a profile row; `last_seen_run` is the newest. When either differs from the + # window extreme on its side, a one-line note explains the leading/trailing `—` + # cells in the rendered trend table. + first_seen_run: ProfilingRun | None = None + last_seen_run: ProfilingRun | None = None + if needs_profile_rows: + for run in reversed(runs): + if run.id in profile_by_run: + first_seen_run = run + break + for run in runs: + if run.id in profile_by_run: + last_seen_run = run + break + + return _render_trends( + tg_name=tg.table_groups_name, + runs=runs, + metrics=metric_enums, + profile_by_run=profile_by_run, + hygiene_counts=hygiene_counts, + table_name=table_name, + column_name=column_name, + first_seen_run=first_seen_run, + last_seen_run=last_seen_run, + needs_profile_rows=needs_profile_rows, + ) + + +def _entity_label(table_name: str | None, column_name: str | None) -> str: + if column_name is not None: + return f"`{table_name}.{column_name}`" + if table_name is not None: + return f"`{table_name}`" + return "" + + +def _trend_cell( + metric: ProfileMetric, + run: ProfilingRun, + profile_by_run: dict[UUID, ProfileResult], + hygiene_counts: dict[UUID, int], +) -> str: + if metric is ProfileMetric.PROFILING_SCORE: + return _format_metric_value(metric, run.dq_score_profiling) + if metric is ProfileMetric.HYGIENE_COUNT: + return _format_metric_value(metric, hygiene_counts.get(run.id, 0)) + pr = profile_by_run.get(run.id) + return _format_metric_value(metric, _column_metric_value(metric, pr)) + + +def _render_trends( + tg_name: str, + runs: list[ProfilingRun], + metrics: list[ProfileMetric], + profile_by_run: dict[UUID, ProfileResult], + hygiene_counts: dict[UUID, int], + table_name: str | None, + column_name: str | None, + first_seen_run: ProfilingRun | None, + last_seen_run: ProfilingRun | None, + needs_profile_rows: bool, +) -> str: + doc = MdDoc() + entity = _entity_label(table_name, column_name) + title = f"Profiling trends for {entity} in `{tg_name}`" if entity else f"Profiling trends for `{tg_name}`" + doc.heading(1, title) + doc.field("Runs included", len(runs)) + doc.field("Oldest run", runs[-1].profiling_starttime) + doc.field("Newest run", runs[0].profiling_starttime) + + if needs_profile_rows and first_seen_run is None: + doc.text( + f"_{entity} not present in any of the last {len(runs)} runs — nothing to trend._" + ) + return doc.render() + + if ( + needs_profile_rows + and first_seen_run is not None + and first_seen_run.id != runs[-1].id + ): + doc.text( + f"_{entity} first appears in the run started " + f"{_format_run_label(first_seen_run)}._" + ) + if ( + needs_profile_rows + and last_seen_run is not None + and last_seen_run.id != runs[0].id + ): + doc.text( + f"_{entity} last appears in the run started " + f"{_format_run_label(last_seen_run)}._" + ) + + # Newest-first columns + headers = ["Metric", *(_format_run_label(run) for run in runs)] + rows: list[list[str]] = [] + for metric in metrics: + row = [metric.value] + for run in runs: + row.append(_trend_cell(metric, run, profile_by_run, hygiene_counts)) + rows.append(row) + doc.table(headers, rows) + + return doc.render() + + +# --------------------------------------------------------------------------- +# Schema history +# --------------------------------------------------------------------------- + + +class _TableSnapshot(NamedTuple): + columns: dict[str, "_ColumnSnapshot"] + record_ct: int | None + + +class _ColumnSnapshot(NamedTuple): + column_type: str | None + general_type: str | None + db_data_type: str | None + + +def _build_run_snapshots(rows: Iterable[ProfileResult]) -> dict[UUID, dict[tuple[str, str], _TableSnapshot]]: + """Reduce per-(run, table) profile rows to a column-snapshot map.""" + accumulator: dict[UUID, dict[tuple[str, str], dict[str, _ColumnSnapshot]]] = defaultdict(lambda: defaultdict(dict)) + record_ct: dict[UUID, dict[tuple[str, str], int | None]] = defaultdict(dict) + for row in rows: + run_id = row.profile_run_id + table_key = (row.schema_name, row.table_name) + accumulator[run_id][table_key][row.column_name] = _ColumnSnapshot( + column_type=row.column_type, + general_type=row.general_type, + db_data_type=row.db_data_type, + ) + # All rows in a (run, table) carry the same record_ct; first one wins. + record_ct[run_id].setdefault(table_key, row.record_ct) + + out: dict[UUID, dict[tuple[str, str], _TableSnapshot]] = {} + for run_id, table_columns in accumulator.items(): + out[run_id] = { + tk: _TableSnapshot(columns=cols, record_ct=record_ct[run_id].get(tk)) + for tk, cols in table_columns.items() + } + return out + + +@with_database_session +@mcp_permission("catalog") +def get_schema_history(table_group_id: str, limit: int = 10) -> str: + """Show a per-run timeline of structural changes across recent profiling runs — tables and columns added or dropped, type changes, and record-count deltas per table. + + Args: + table_group_id: UUID of the table group, e.g. from `get_data_inventory`. + limit: Number of recent runs to render deltas for (default 10, max 20). One + additional anchor run is pulled when available so the oldest in-window + run has a baseline to diff against. + """ + validate_limit(limit, 20) + tg = resolve_table_group(table_group_id) + + runs = ProfilingRun.list_recent_complete(tg.id, limit=limit + 1) + if len(runs) < 2: + if not runs: + return f"No completed profiling runs found for table group `{tg.table_groups_name}`." + return ( + f"Only one completed profiling run exists for table group `{tg.table_groups_name}` — " + "at least two are needed to render a history." + ) + + run_ids = [r.id for r in runs] + rows = ProfileResult.select_for_runs(run_ids=run_ids) + snapshots = _build_run_snapshots(rows) + + return _render_schema_history(tg.table_groups_name, runs, snapshots) + + +def _render_schema_history( + tg_name: str, + runs: list[ProfilingRun], + snapshots: dict[UUID, dict[tuple[str, str], _TableSnapshot]], +) -> str: + doc = MdDoc() + doc.heading(1, f"Schema history for `{tg_name}`") + doc.field("Runs analyzed", len(runs) - 1) + doc.field("Window", f"{_format_run_label(runs[-1])} → {_format_run_label(runs[0])}") + + # Iterate newest → oldest, pairing each run with its predecessor. + for index in range(len(runs) - 1): + target = runs[index] + baseline = runs[index + 1] + section_lines = _format_schema_delta( + target_snap=snapshots.get(target.id, {}), + baseline_snap=snapshots.get(baseline.id, {}), + ) + doc.heading(2, f"Run started {_format_run_label(target)}") + doc.field("Profiling Run", target.job_execution_id, code=True) + if section_lines: + doc.bullets(section_lines) + else: + doc.text("_No structural change since previous run._") + + return doc.render() + + +def _format_run_label(run: ProfilingRun) -> str: + """Format a run's start time as ``YYYY-MM-DD HH:MM`` — short enough for column + headers, precise enough to disambiguate same-day runs.""" + return run.profiling_starttime.strftime("%Y-%m-%d %H:%M") + + +def _format_table_key(key: tuple[str, str]) -> str: + schema, table = key + return f"`{schema}.{table}`" if schema else f"`{table}`" + + +def _format_schema_delta( + target_snap: dict[tuple[str, str], _TableSnapshot], + baseline_snap: dict[tuple[str, str], _TableSnapshot], +) -> list[str]: + lines: list[str] = [] + target_tables = set(target_snap) + baseline_tables = set(baseline_snap) + + added_tables = sorted(target_tables - baseline_tables) + for key in added_tables: + col_ct = len(target_snap[key].columns) + lines.append(f"Table added: {_format_table_key(key)} ({col_ct} columns)") + + dropped_tables = sorted(baseline_tables - target_tables) + for key in dropped_tables: + col_ct = len(baseline_snap[key].columns) + lines.append(f"Table dropped: {_format_table_key(key)} ({col_ct} columns)") + + for key in sorted(target_tables & baseline_tables): + target_table = target_snap[key] + baseline_table = baseline_snap[key] + column_changes = _format_column_delta(target_table.columns, baseline_table.columns) + record_delta = _format_record_delta(target_table.record_ct, baseline_table.record_ct) + for change in column_changes: + lines.append(f"{_format_table_key(key)}: {change}") + if record_delta is not None: + lines.append(f"{_format_table_key(key)}: Record count {record_delta}") + return lines + + +def _format_column_delta( + target_cols: dict[str, _ColumnSnapshot], + baseline_cols: dict[str, _ColumnSnapshot], +) -> list[str]: + out: list[str] = [] + target_names = set(target_cols) + baseline_names = set(baseline_cols) + for name in sorted(target_names - baseline_names): + snap = target_cols[name] + out.append(f"column `{name}` added ({snap.column_type or snap.db_data_type or '—'})") + for name in sorted(baseline_names - target_names): + snap = baseline_cols[name] + out.append(f"column `{name}` dropped (was {snap.column_type or snap.db_data_type or '—'})") + for name in sorted(target_names & baseline_names): + target_col = target_cols[name] + baseline_col = baseline_cols[name] + if target_col.column_type != baseline_col.column_type and target_col.column_type and baseline_col.column_type: + out.append( + f"column `{name}` retyped: {baseline_col.column_type} → {target_col.column_type}" + ) + return out + + +def _format_record_delta(target_ct: int | None, baseline_ct: int | None) -> str | None: + if target_ct is None or baseline_ct is None: + return None + if target_ct == baseline_ct: + return None + return f"{baseline_ct:,} → {target_ct:,}" diff --git a/tests/unit/mcp/test_tools_profile_history.py b/tests/unit/mcp/test_tools_profile_history.py new file mode 100644 index 00000000..f40a5187 --- /dev/null +++ b/tests/unit/mcp/test_tools_profile_history.py @@ -0,0 +1,570 @@ +from datetime import datetime +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest + +from testgen.common.models.data_column import ProfileMetric +from testgen.mcp.exceptions import MCPUserError +from testgen.mcp.tools.profile_history import ( + _column_metric_value, + _delta_cell, + _format_metric_value, + _validate_metric_scope, + compare_profiling_runs, + get_profiling_trends, + get_schema_history, +) + +# ---------------------------------------------------------------------- +# Helpers +# ---------------------------------------------------------------------- + + +def _profile_row( + run_id=None, + table_name="orders", + column_name="customer_email", + general_type="A", + schema_name="demo", + record_ct=1000, + null_value_ct=50, + distinct_value_ct=900, + filled_value_ct=10, + column_type="varchar(200)", + db_data_type="varchar", + functional_data_type="Person Email", + pii_flag=None, + datatype_suggestion=None, + avg_length=18.0, + min_length=5, + max_length=40, + min_text=None, + max_text=None, + min_value=None, + max_value=None, + avg_value=None, + stdev_value=None, + min_date=None, + max_date=None, + boolean_true_ct=None, +): + row = MagicMock() + row.profile_run_id = run_id or uuid4() + row.schema_name = schema_name + row.table_name = table_name + row.column_name = column_name + row.general_type = general_type + row.column_type = column_type + row.db_data_type = db_data_type + row.functional_data_type = functional_data_type + row.pii_flag = pii_flag + row.datatype_suggestion = datatype_suggestion + row.record_ct = record_ct + row.null_value_ct = null_value_ct + row.distinct_value_ct = distinct_value_ct + row.filled_value_ct = filled_value_ct + row.avg_length = avg_length + row.min_length = min_length + row.max_length = max_length + row.min_text = min_text + row.max_text = max_text + row.min_value = min_value + row.max_value = max_value + row.avg_value = avg_value + row.stdev_value = stdev_value + row.min_date = min_date + row.max_date = max_date + row.boolean_true_ct = boolean_true_ct + return row + + +def _profiling_run( + id_=None, + job_execution_id=None, + table_groups_id=None, + status="Complete", + profiling_starttime=None, + dq_score_profiling=0.92, + table_groups_name="Demo Sales", +): + run = MagicMock() + run.id = id_ or uuid4() + run.job_execution_id = job_execution_id or uuid4() + run.table_groups_id = table_groups_id or uuid4() + run.status = status + run.profiling_starttime = profiling_starttime or datetime(2026, 5, 10, 12, 0) + run.dq_score_profiling = dq_score_profiling + run.table_groups_name = table_groups_name + return run + + +def _table_group(tg_id=None, project_code="demo", name="Demo Sales"): + tg = MagicMock() + tg.id = tg_id or uuid4() + tg.project_code = project_code + tg.table_groups_name = name + return tg + + +# ---------------------------------------------------------------------- +# _column_metric_value +# ---------------------------------------------------------------------- + + +def test_column_metric_value_ratios(): + row = _profile_row(record_ct=1000, null_value_ct=250, distinct_value_ct=900, filled_value_ct=100) + assert _column_metric_value(ProfileMetric.NULL_RATIO, row) == 0.25 + assert _column_metric_value(ProfileMetric.DISTINCT_RATIO, row) == 0.9 + assert _column_metric_value(ProfileMetric.FILLED_RATIO, row) == 0.1 + + +def test_column_metric_value_record_count(): + row = _profile_row(record_ct=1234) + assert _column_metric_value(ProfileMetric.RECORD_COUNT, row) == 1234 + + +def test_column_metric_value_zero_record_ct_returns_none(): + row = _profile_row(record_ct=0, null_value_ct=0, distinct_value_ct=0) + assert _column_metric_value(ProfileMetric.NULL_RATIO, row) is None + assert _column_metric_value(ProfileMetric.DISTINCT_RATIO, row) is None + + +def test_column_metric_value_missing_row_returns_none(): + assert _column_metric_value(ProfileMetric.NULL_RATIO, None) is None + assert _column_metric_value(ProfileMetric.RECORD_COUNT, None) is None + + +def test_column_metric_value_type_restriction(): + numeric_row = _profile_row(general_type="N", avg_value=5.5, avg_length=None) + # Avg Length only applies to Alpha columns + assert _column_metric_value(ProfileMetric.AVG_LENGTH, numeric_row) is None + assert _column_metric_value(ProfileMetric.AVG, numeric_row) == 5.5 + + alpha_row = _profile_row(general_type="A", avg_length=18.0, avg_value=None) + assert _column_metric_value(ProfileMetric.AVG_LENGTH, alpha_row) == 18.0 + assert _column_metric_value(ProfileMetric.AVG, alpha_row) is None + + +def test_column_metric_value_date_min_max(): + row = _profile_row( + general_type="D", + min_date=datetime(2024, 1, 3), + max_date=datetime(2026, 5, 10), + ) + assert _column_metric_value(ProfileMetric.MIN_DATE, row) == datetime(2024, 1, 3) + assert _column_metric_value(ProfileMetric.MAX_DATE, row) == datetime(2026, 5, 10) + + +def test_column_metric_value_boolean_true_count(): + row = _profile_row(general_type="B", boolean_true_ct=42) + assert _column_metric_value(ProfileMetric.TRUE_COUNT, row) == 42 + + +# ---------------------------------------------------------------------- +# _format_metric_value +# ---------------------------------------------------------------------- + + +def test_format_metric_value_percent(): + assert _format_metric_value(ProfileMetric.NULL_RATIO, 0.25) == "25.0%" + assert _format_metric_value(ProfileMetric.DISTINCT_RATIO, 0.9) == "90.0%" + + +def test_format_metric_value_profiling_score_uses_friendly_score(): + # Profiling Score follows the codebase-wide friendly_score convention: + # value (0-1) scaled to 0-100 with no '%' suffix. + assert _format_metric_value(ProfileMetric.PROFILING_SCORE, 0.92) == "92.0" + assert _format_metric_value(ProfileMetric.PROFILING_SCORE, 1.0) == "100" + + +def test_format_metric_value_record_count_thousands_separator(): + assert _format_metric_value(ProfileMetric.RECORD_COUNT, 12345) == "12,345" + + +def test_format_metric_value_datetime_date_only(): + assert _format_metric_value(ProfileMetric.MIN_DATE, datetime(2024, 1, 3, 14, 30)) == "2024-01-03" + + +def test_format_metric_value_none(): + assert _format_metric_value(ProfileMetric.NULL_RATIO, None) == "—" + + +# ---------------------------------------------------------------------- +# _delta_cell +# ---------------------------------------------------------------------- + + +def test_delta_cell_unchanged(): + assert _delta_cell(ProfileMetric.NULL_RATIO, 0.25, 0.25) == "25.0% (=)" + + +def test_delta_cell_changed(): + assert _delta_cell(ProfileMetric.NULL_RATIO, 0.30, 0.05) == "30.0% → 5.0%" + + +def test_delta_cell_dates_render_as_dates_only(): + # Different timestamps on the same date format identically -> rendered as (=) + a = datetime(2024, 1, 3, 6, 0) + b = datetime(2024, 1, 3, 18, 0) + assert _delta_cell(ProfileMetric.MIN_DATE, a, b) == "2024-01-03 (=)" + + +def test_delta_cell_none_baseline(): + assert _delta_cell(ProfileMetric.RECORD_COUNT, None, 1000) == "— → 1,000" + + +# ---------------------------------------------------------------------- +# _validate_metric_scope +# ---------------------------------------------------------------------- + + +def test_validate_metric_scope_column_metric_requires_column(): + with pytest.raises(MCPUserError, match="require both `table_name` and `column_name`"): + _validate_metric_scope([ProfileMetric.NULL_RATIO], table_name="orders", column_name=None) + + +def test_validate_metric_scope_table_metric_requires_table(): + with pytest.raises(MCPUserError, match="require `table_name`"): + _validate_metric_scope([ProfileMetric.RECORD_COUNT], table_name=None, column_name=None) + + +def test_validate_metric_scope_tg_metric_accepts_any_scope(): + # No exception when no scope args provided + _validate_metric_scope([ProfileMetric.PROFILING_SCORE], table_name=None, column_name=None) + _validate_metric_scope([ProfileMetric.HYGIENE_COUNT], table_name=None, column_name=None) + + +def test_validate_metric_scope_mixed_scopes_all_satisfied(): + _validate_metric_scope( + [ProfileMetric.NULL_RATIO, ProfileMetric.RECORD_COUNT, ProfileMetric.PROFILING_SCORE], + table_name="orders", + column_name="email", + ) + + +# ---------------------------------------------------------------------- +# compare_profiling_runs — flow tests +# ---------------------------------------------------------------------- + + +@patch("testgen.mcp.tools.profile_history.HygieneIssue") +@patch("testgen.mcp.tools.profile_history.HygieneIssueType") +@patch("testgen.mcp.tools.profile_history.ProfileResult") +@patch("testgen.mcp.tools.profile_history.resolve_profiling_run") +def test_compare_profiling_runs_auto_baseline( + mock_resolve, mock_pr, mock_iss_type, mock_iss, db_session_mock, +): + tg_id = uuid4() + target_run = _profiling_run(table_groups_id=tg_id, profiling_starttime=datetime(2026, 5, 13)) + baseline_run = _profiling_run(table_groups_id=tg_id, profiling_starttime=datetime(2026, 5, 10)) + target_run.get_previous.return_value = baseline_run + mock_resolve.return_value = target_run + + target_row = _profile_row(run_id=target_run.id, null_value_ct=50) + baseline_row = _profile_row(run_id=baseline_run.id, null_value_ct=300) + mock_pr.select_for_runs.return_value = [target_row, baseline_row] + mock_iss.select_where.return_value = [] + mock_iss_type.select_where.return_value = [] + + result = compare_profiling_runs(str(target_run.job_execution_id)) + + assert "Profiling Run Comparison" in result + assert "Target" in result and "Baseline" in result + assert "Profiling Run" in result and "Started" in result + target_run.get_previous.assert_called_once() + + +@patch("testgen.mcp.tools.profile_history.resolve_profiling_run") +def test_compare_profiling_runs_rejects_non_completed_target(mock_resolve, db_session_mock): + target_run = _profiling_run(status="Running") + mock_resolve.return_value = target_run + + with pytest.raises(MCPUserError, match="Target run is in `Running` state"): + compare_profiling_runs(str(target_run.job_execution_id)) + + +@patch("testgen.mcp.tools.profile_history.resolve_profiling_run") +def test_compare_profiling_runs_rejects_cancelled_target(mock_resolve, db_session_mock): + target_run = _profiling_run(status="Cancelled") + mock_resolve.return_value = target_run + + with pytest.raises(MCPUserError, match="`Cancelled`"): + compare_profiling_runs(str(target_run.job_execution_id)) + + +@patch("testgen.mcp.tools.profile_history.resolve_profiling_run") +def test_compare_profiling_runs_rejects_cross_table_group(mock_resolve, db_session_mock): + target_run = _profiling_run(table_groups_id=uuid4()) + baseline_run = _profiling_run(table_groups_id=uuid4()) + mock_resolve.side_effect = [target_run, baseline_run] + + with pytest.raises(MCPUserError, match="same table group"): + compare_profiling_runs( + str(target_run.job_execution_id), + str(baseline_run.job_execution_id), + ) + + +@patch("testgen.mcp.tools.profile_history.resolve_profiling_run") +def test_compare_profiling_runs_auto_baseline_first_run(mock_resolve, db_session_mock): + target_run = _profiling_run() + target_run.get_previous.return_value = None + mock_resolve.return_value = target_run + + with pytest.raises(MCPUserError, match="first completed profiling run"): + compare_profiling_runs(str(target_run.job_execution_id)) + + +@patch("testgen.mcp.tools.profile_history.HygieneIssue") +@patch("testgen.mcp.tools.profile_history.HygieneIssueType") +@patch("testgen.mcp.tools.profile_history.ProfileResult") +@patch("testgen.mcp.tools.profile_history.resolve_profiling_run") +def test_compare_profiling_runs_identical_runs_renders_no_changes( + mock_resolve, mock_pr, mock_iss_type, mock_iss, db_session_mock, +): + tg_id = uuid4() + target_run = _profiling_run(table_groups_id=tg_id) + baseline_run = _profiling_run(table_groups_id=tg_id, profiling_starttime=datetime(2026, 5, 1)) + target_run.get_previous.return_value = baseline_run + mock_resolve.return_value = target_run + + target_row = _profile_row(run_id=target_run.id) + baseline_row = _profile_row(run_id=baseline_run.id) # same values + mock_pr.select_for_runs.return_value = [target_row, baseline_row] + mock_iss.select_where.return_value = [] + mock_iss_type.select_where.return_value = [] + + result = compare_profiling_runs(str(target_run.job_execution_id)) + + assert "No changes between target and baseline" in result + + +# ---------------------------------------------------------------------- +# get_profiling_trends +# ---------------------------------------------------------------------- + + +@patch("testgen.mcp.tools.profile_history.ProfilingRun") +@patch("testgen.mcp.tools.profile_history.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_profiling_trends_happy_path(mock_tg_cls, mock_pr, mock_pr_cls, db_session_mock): + tg = _table_group() + mock_tg_cls.get.return_value = tg + + run_old = _profiling_run(profiling_starttime=datetime(2026, 5, 1)) + run_new = _profiling_run(profiling_starttime=datetime(2026, 5, 13)) + mock_pr_cls.list_recent_complete.return_value = [run_new, run_old] + mock_pr_cls.count_confirmed_hygiene_issues.return_value = {} + + rows = [ + _profile_row(run_id=run_old.id, null_value_ct=300), + _profile_row(run_id=run_new.id, null_value_ct=50), + ] + mock_pr.select_for_runs.return_value = rows + + result = get_profiling_trends( + str(tg.id), + metrics=["Null Ratio", "Distinct Ratio"], + table_name="orders", + column_name="customer_email", + ) + + assert "Profiling trends" in result + assert "Null Ratio" in result + assert "Distinct Ratio" in result + assert "2026-05-13" in result and "2026-05-01" in result + + +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_profiling_trends_invalid_metric(mock_tg_cls, db_session_mock): + mock_tg_cls.get.return_value = _table_group() + + with pytest.raises(MCPUserError, match="Invalid metrics"): + get_profiling_trends(str(uuid4()), metrics=["Unknown Metric"]) + + +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_profiling_trends_empty_metrics(mock_tg_cls, db_session_mock): + mock_tg_cls.get.return_value = _table_group() + + with pytest.raises(MCPUserError, match="cannot be empty"): + get_profiling_trends(str(uuid4()), metrics=[]) + + +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_profiling_trends_column_requires_table(mock_tg_cls, db_session_mock): + mock_tg_cls.get.return_value = _table_group() + + with pytest.raises(MCPUserError, match="`column_name` requires `table_name`"): + get_profiling_trends( + str(uuid4()), + metrics=["Null Ratio"], + column_name="email", + ) + + +@patch("testgen.mcp.tools.profile_history.ProfilingRun") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_profiling_trends_no_runs(mock_tg_cls, mock_pr_cls, db_session_mock): + mock_tg_cls.get.return_value = _table_group() + mock_pr_cls.list_recent_complete.return_value = [] + + # TG-scope metric so we skip the profile-row fetch entirely + result = get_profiling_trends(str(uuid4()), metrics=["Profiling Score"]) + assert "No completed profiling runs" in result + + +@patch("testgen.mcp.tools.profile_history.ProfilingRun") +@patch("testgen.mcp.tools.profile_history.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_profiling_trends_first_appears_note(mock_tg_cls, mock_pr, mock_pr_cls, db_session_mock): + """Entity missing from the oldest run but present in newer runs.""" + mock_tg_cls.get.return_value = _table_group() + run_old = _profiling_run(profiling_starttime=datetime(2026, 5, 1, 9, 0)) + run_mid = _profiling_run(profiling_starttime=datetime(2026, 5, 10, 14, 0)) + run_new = _profiling_run(profiling_starttime=datetime(2026, 5, 13, 10, 0)) + mock_pr_cls.list_recent_complete.return_value = [run_new, run_mid, run_old] + # Only mid and new runs have the column — entity first appears at run_mid. + mock_pr.select_for_runs.return_value = [ + _profile_row(run_id=run_mid.id), + _profile_row(run_id=run_new.id), + ] + + result = get_profiling_trends( + str(uuid4()), + metrics=["Null Ratio"], + table_name="orders", + column_name="customer_email", + ) + assert "first appears in the run started 2026-05-10 14:00" in result + assert "last appears" not in result # present in newest run, no trailing-gap note + + +@patch("testgen.mcp.tools.profile_history.ProfilingRun") +@patch("testgen.mcp.tools.profile_history.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_profiling_trends_last_appears_note(mock_tg_cls, mock_pr, mock_pr_cls, db_session_mock): + """Entity present in early runs but missing from the newest run.""" + mock_tg_cls.get.return_value = _table_group() + run_old = _profiling_run(profiling_starttime=datetime(2026, 5, 1, 9, 0)) + run_mid = _profiling_run(profiling_starttime=datetime(2026, 5, 10, 14, 0)) + run_new = _profiling_run(profiling_starttime=datetime(2026, 5, 13, 10, 0)) + mock_pr_cls.list_recent_complete.return_value = [run_new, run_mid, run_old] + # Only old and mid runs have the column — entity last appears at run_mid. + mock_pr.select_for_runs.return_value = [ + _profile_row(run_id=run_old.id), + _profile_row(run_id=run_mid.id), + ] + + result = get_profiling_trends( + str(uuid4()), + metrics=["Null Ratio"], + table_name="orders", + column_name="legacy_id", + ) + assert "last appears in the run started 2026-05-10 14:00" in result + assert "first appears" not in result # present in oldest run, no leading-gap note + + +@patch("testgen.mcp.tools.profile_history.ProfilingRun") +@patch("testgen.mcp.tools.profile_history.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_profiling_trends_both_notes(mock_tg_cls, mock_pr, mock_pr_cls, db_session_mock): + """Entity has a bounded lifetime — missing on both ends of the window.""" + mock_tg_cls.get.return_value = _table_group() + run_oldest = _profiling_run(profiling_starttime=datetime(2026, 5, 9, 9, 0)) + run_first = _profiling_run(profiling_starttime=datetime(2026, 5, 10, 14, 0)) + run_last = _profiling_run(profiling_starttime=datetime(2026, 5, 12, 22, 0)) + run_newest = _profiling_run(profiling_starttime=datetime(2026, 5, 13, 10, 0)) + mock_pr_cls.list_recent_complete.return_value = [run_newest, run_last, run_first, run_oldest] + # Only the middle two runs carry the column. + mock_pr.select_for_runs.return_value = [ + _profile_row(run_id=run_first.id), + _profile_row(run_id=run_last.id), + ] + + result = get_profiling_trends( + str(uuid4()), + metrics=["Null Ratio"], + table_name="orders", + column_name="customer_email_v2", + ) + assert "first appears in the run started 2026-05-10 14:00" in result + assert "last appears in the run started 2026-05-12 22:00" in result + + +@patch("testgen.mcp.tools.profile_history.ProfilingRun") +@patch("testgen.mcp.tools.profile_history.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_profiling_trends_no_notes_when_present_throughout(mock_tg_cls, mock_pr, mock_pr_cls, db_session_mock): + """Entity present in every run — no first/last-appears noise.""" + mock_tg_cls.get.return_value = _table_group() + run_old = _profiling_run(profiling_starttime=datetime(2026, 5, 1, 9, 0)) + run_new = _profiling_run(profiling_starttime=datetime(2026, 5, 13, 10, 0)) + mock_pr_cls.list_recent_complete.return_value = [run_new, run_old] + mock_pr.select_for_runs.return_value = [ + _profile_row(run_id=run_old.id), + _profile_row(run_id=run_new.id), + ] + + result = get_profiling_trends( + str(uuid4()), + metrics=["Null Ratio"], + table_name="orders", + column_name="customer_id", + ) + assert "first appears" not in result + assert "last appears" not in result + + +# ---------------------------------------------------------------------- +# get_schema_history +# ---------------------------------------------------------------------- + + +@patch("testgen.mcp.tools.profile_history.ProfilingRun") +@patch("testgen.mcp.tools.profile_history.ProfileResult") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_schema_history_happy_path(mock_tg_cls, mock_pr, mock_pr_cls, db_session_mock): + tg = _table_group() + mock_tg_cls.get.return_value = tg + + run_old = _profiling_run(profiling_starttime=datetime(2026, 5, 1)) + run_new = _profiling_run(profiling_starttime=datetime(2026, 5, 13)) + mock_pr_cls.list_recent_complete.return_value = [run_new, run_old] + + rows = [ + _profile_row(run_id=run_old.id, table_name="orders", column_name="id", general_type="N", record_ct=900), + _profile_row(run_id=run_old.id, table_name="orders", column_name="email", general_type="A", record_ct=900), + _profile_row(run_id=run_new.id, table_name="orders", column_name="id", general_type="N", record_ct=1000), + _profile_row(run_id=run_new.id, table_name="orders", column_name="email", general_type="A", record_ct=1000), + _profile_row(run_id=run_new.id, table_name="orders", column_name="phone", general_type="A", record_ct=1000), + ] + mock_pr.select_for_runs.return_value = rows + + result = get_schema_history(str(tg.id)) + + assert "Schema history" in result + assert "phone" in result # newly added column + assert "Record count" in result # 900 → 1,000 delta + + +@patch("testgen.mcp.tools.profile_history.ProfilingRun") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_schema_history_single_run_short_circuits(mock_tg_cls, mock_pr_cls, db_session_mock): + mock_tg_cls.get.return_value = _table_group() + mock_pr_cls.list_recent_complete.return_value = [_profiling_run()] + + result = get_schema_history(str(uuid4())) + assert "at least two are needed" in result + + +@patch("testgen.mcp.tools.profile_history.ProfilingRun") +@patch("testgen.mcp.tools.common.TableGroup") +def test_get_schema_history_no_runs(mock_tg_cls, mock_pr_cls, db_session_mock): + mock_tg_cls.get.return_value = _table_group() + mock_pr_cls.list_recent_complete.return_value = [] + + result = get_schema_history(str(uuid4())) + assert "No completed profiling runs" in result From 49c3d2ced2ab023f31186b80ce402fe1d4431772 Mon Sep 17 00:00:00 2001 From: Ricardo Boni Date: Sun, 17 May 2026 12:15:13 -0400 Subject: [PATCH 25/58] refactor(mcp): apply TG-1068 review feedback --- testgen/common/models/data_column.py | 26 +++++---- testgen/common/models/profiling_run.py | 4 +- testgen/mcp/tools/profile_history.py | 49 ++++++++++------- testgen/mcp/tools/profiling.py | 4 +- testgen/mcp/tools/reference.py | 4 +- tests/unit/mcp/test_tools_profile_history.py | 58 ++++++++++++++------ 6 files changed, 91 insertions(+), 54 deletions(-) diff --git a/testgen/common/models/data_column.py b/testgen/common/models/data_column.py index e0e5d8ec..cee7d088 100644 --- a/testgen/common/models/data_column.py +++ b/testgen/common/models/data_column.py @@ -96,7 +96,9 @@ class ProfileMetric(StrEnum): Covers general column ratios (null / distinct / filled), type-specific statistics (length, numeric range, date range, true count), table-level - record count, and table-group rollups (profiling score, hygiene count). + row count, and table-group rollups (profiling score, hygiene issues). + + Labels align with the field names in ``column_profile_fields_resource``. """ # Apply to any column @@ -104,22 +106,22 @@ class ProfileMetric(StrEnum): DISTINCT_RATIO = "Distinct Ratio" FILLED_RATIO = "Filled Ratio" # Apply to the parent table - RECORD_COUNT = "Record Count" + RECORD_COUNT = "Row Count" # Apply to the whole table group PROFILING_SCORE = "Profiling Score" - HYGIENE_COUNT = "Hygiene Count" + HYGIENE_COUNT = "Hygiene Issues" # Alpha-only - MIN_LENGTH = "Min Length" - MAX_LENGTH = "Max Length" - AVG_LENGTH = "Avg Length" + MIN_LENGTH = "Minimum Length" + MAX_LENGTH = "Maximum Length" + AVG_LENGTH = "Average Length" # Numeric-only - MIN = "Min" - MAX = "Max" - AVG = "Avg" - STDEV = "Stdev" + MIN = "Minimum Value" + MAX = "Maximum Value" + AVG = "Average Value" + STDEV = "Standard Deviation" # Date-only - MIN_DATE = "Min Date" - MAX_DATE = "Max Date" + MIN_DATE = "Minimum Date" + MAX_DATE = "Maximum Date" # Boolean-only TRUE_COUNT = "True Count" diff --git a/testgen/common/models/profiling_run.py b/testgen/common/models/profiling_run.py index c37bd407..9faf8bf0 100644 --- a/testgen/common/models/profiling_run.py +++ b/testgen/common/models/profiling_run.py @@ -11,7 +11,7 @@ from sqlalchemy.orm.attributes import flag_modified from sqlalchemy.sql.expression import case -from testgen.common.enums import JobStatus +from testgen.common.enums import Disposition, JobStatus from testgen.common.models import get_current_session from testgen.common.models.connection import Connection from testgen.common.models.entity import ENTITY_HASH_FUNCS, Entity, EntityMinimal @@ -450,7 +450,7 @@ def count_confirmed_hygiene_issues(cls, run_ids: list[UUID]) -> dict[UUID, int]: select(HygieneIssue.profile_run_id, func.count()) .where( HygieneIssue.profile_run_id.in_(run_ids), - func.coalesce(HygieneIssue.disposition, "Confirmed") == "Confirmed", + func.coalesce(HygieneIssue.disposition, Disposition.CONFIRMED) == Disposition.CONFIRMED, ) .group_by(HygieneIssue.profile_run_id) ) diff --git a/testgen/mcp/tools/profile_history.py b/testgen/mcp/tools/profile_history.py index a4eb7897..19614365 100644 --- a/testgen/mcp/tools/profile_history.py +++ b/testgen/mcp/tools/profile_history.py @@ -16,11 +16,13 @@ from sqlalchemy import func -from testgen.common.models import with_database_session +from testgen.common.enums import Disposition, JobStatus +from testgen.common.models import get_current_session, with_database_session from testgen.common.models.data_column import ProfileMetric from testgen.common.models.hygiene_issue import HygieneIssue, HygieneIssueType +from testgen.common.models.job_execution import JobExecution from testgen.common.models.profile_result import ProfileResult -from testgen.common.models.profiling_run import ProfilingRun +from testgen.common.models.profiling_run import ProfilingRun, ProfilingRunSummary from testgen.mcp.exceptions import MCPUserError from testgen.mcp.permissions import mcp_permission from testgen.mcp.tools.common import ( @@ -123,7 +125,7 @@ def _column_metric_value(metric: ProfileMetric, pr: ProfileResult | None) -> obj """Extract a column-scope metric value from a ProfileResult row. Returns ``None`` if the row is missing or the metric doesn't apply to the - column's ``general_type`` (e.g. ``Avg Length`` on a numeric column). + column's ``general_type`` (e.g. ``Average Length`` on a numeric column). """ if pr is None: return None @@ -198,13 +200,14 @@ def _delta_cell(metric: ProfileMetric, baseline: object | None, target: object | # Run-state guard # --------------------------------------------------------------------------- -_REQUIRED_RUN_STATUS = "Complete" - def _require_completed(run: ProfilingRun, label: str) -> None: - if run.status != _REQUIRED_RUN_STATUS: + """Raise if the run's job execution isn't completed.""" + je = get_current_session().get(JobExecution, run.job_execution_id) + if je.status != JobStatus.COMPLETED: + status_label = ProfilingRunSummary.STATUS_LABEL.get(je.status, je.status) raise MCPUserError( - f"{label} run is in `{run.status}` state — comparison requires a completed run." + f"{label} run is in `{status_label}` state — comparison requires a completed run." ) @@ -294,9 +297,12 @@ def compare_profiling_runs( snapshot). When omitted, defaults to the previous completed run on the same table group. table_name: Optional — restrict the comparison to one table (case-sensitive). - column_name: Optional — restrict the comparison to one column (requires - `table_name` when used in the diff render but accepted independently). + column_name: Optional — restrict the comparison to one column (case-sensitive); requires + `table_name`. """ + if column_name is not None and table_name is None: + raise MCPUserError("`column_name` requires `table_name`.") + target_run = resolve_profiling_run(target_job_execution_id) _require_completed(target_run, "Target") @@ -309,13 +315,13 @@ def compare_profiling_runs( ) else: baseline_run = resolve_profiling_run(baseline_job_execution_id) - _require_completed(baseline_run, "Baseline") if baseline_run.table_groups_id != target_run.table_groups_id: raise MCPUserError( "Both runs must belong to the same table group to be comparable. " f"Target is in table group `{target_run.table_groups_id}`, " f"baseline is in table group `{baseline_run.table_groups_id}`." ) + _require_completed(baseline_run, "Baseline") rows = ProfileResult.select_for_runs( run_ids=[target_run.id, baseline_run.id], @@ -361,7 +367,7 @@ def _diff_hygiene_issues( """ clauses = [ HygieneIssue.profile_run_id.in_([target_run_id, baseline_run_id]), - func.coalesce(HygieneIssue.disposition, "Confirmed") == "Confirmed", + func.coalesce(HygieneIssue.disposition, Disposition.CONFIRMED) == Disposition.CONFIRMED, ] if table_name is not None: clauses.append(HygieneIssue.table_name == table_name) @@ -537,19 +543,20 @@ def get_profiling_trends( """Show a time series of caller-named profiling metrics across recent completed runs of a table group. Metric scope rules: - - Column-level metrics (e.g. `Null Ratio`, `Avg Length`, `Min`) require both + - Column-level metrics (e.g. `Null Ratio`, `Average Length`, `Minimum Value`) require both `table_name` and `column_name`. - - `Record Count` is table-level and requires `table_name`. - - `Profiling Score` and `Hygiene Count` are table-group-level and accept any scope. + - `Row Count` is table-level and requires `table_name`. + - `Profiling Score` and `Hygiene Issues` are table-group-level and accept any scope. - Type-specific metrics return `—` for runs where the column's general type - didn't match (e.g. `Min` on a column that was Alpha in an earlier run). + didn't match (e.g. `Minimum Value` on a column that was Alpha in an earlier run). Args: table_group_id: UUID of the table group, e.g. from `get_data_inventory`. metrics: One or more metric names. Accepted values: `Null Ratio`, `Distinct Ratio`, - `Filled Ratio`, `Record Count`, `Profiling Score`, `Hygiene Count`, - `Min Length`, `Max Length`, `Avg Length`, `Min`, `Max`, `Avg`, `Stdev`, - `Min Date`, `Max Date`, `True Count`. + `Filled Ratio`, `Row Count`, `Profiling Score`, `Hygiene Issues`, + `Minimum Length`, `Maximum Length`, `Average Length`, `Minimum Value`, + `Maximum Value`, `Average Value`, `Standard Deviation`, `Minimum Date`, + `Maximum Date`, `True Count`. table_name: Optional — restrict to one table (case-sensitive). column_name: Optional — restrict to one column (case-sensitive); requires `table_name`. @@ -844,10 +851,12 @@ def _format_column_delta( baseline_names = set(baseline_cols) for name in sorted(target_names - baseline_names): snap = target_cols[name] - out.append(f"column `{name}` added ({snap.column_type or snap.db_data_type or '—'})") + type_label = snap.column_type or snap.db_data_type + out.append(f"column `{name}` added ({type_label})" if type_label else f"column `{name}` added") for name in sorted(baseline_names - target_names): snap = baseline_cols[name] - out.append(f"column `{name}` dropped (was {snap.column_type or snap.db_data_type or '—'})") + type_label = snap.column_type or snap.db_data_type + out.append(f"column `{name}` dropped (was {type_label})" if type_label else f"column `{name}` dropped") for name in sorted(target_names & baseline_names): target_col = target_cols[name] baseline_col = baseline_cols[name] diff --git a/testgen/mcp/tools/profiling.py b/testgen/mcp/tools/profiling.py index 2f01de89..d8ac35a2 100644 --- a/testgen/mcp/tools/profiling.py +++ b/testgen/mcp/tools/profiling.py @@ -909,8 +909,8 @@ def _render_boolean_block(doc: MdDoc, p: dict) -> None: true_ct = p.get("boolean_true_ct") or 0 value_ct = p.get("value_ct") or 0 false_ct = max(value_ct - true_ct, 0) - doc.field("True", true_ct) - doc.field("False", false_ct) + doc.field("True Count", true_ct) + doc.field("False Count", false_ct) def _render_unknown_block(doc: MdDoc, p: dict) -> None: diff --git a/testgen/mcp/tools/reference.py b/testgen/mcp/tools/reference.py index 655c257e..210e4b89 100644 --- a/testgen/mcp/tools/reference.py +++ b/testgen/mcp/tools/reference.py @@ -208,8 +208,8 @@ def column_profile_fields_resource() -> str: Populated when `General Type == "Boolean"`. -- **True** — Rows where the value is true (count). -- **False** — Rows where the value is false (count, derived as `Value Count - True`). +- **True Count** — Rows where the value is true (count). +- **False Count** — Rows where the value is false (count, derived as `Value Count - True Count`). ## PII Redaction diff --git a/tests/unit/mcp/test_tools_profile_history.py b/tests/unit/mcp/test_tools_profile_history.py index f40a5187..6bec4fe1 100644 --- a/tests/unit/mcp/test_tools_profile_history.py +++ b/tests/unit/mcp/test_tools_profile_history.py @@ -4,6 +4,7 @@ import pytest +from testgen.common.enums import JobStatus from testgen.common.models.data_column import ProfileMetric from testgen.mcp.exceptions import MCPUserError from testgen.mcp.tools.profile_history import ( @@ -16,6 +17,20 @@ get_schema_history, ) + +def _je(status=JobStatus.COMPLETED): + """Build a JobExecution mock for ``session.get(JobExecution, ...)`` returns.""" + je = MagicMock() + je.status = status + return je + + +def _patch_session(jes): + """Patch ``get_current_session`` so ``session.get(JobExecution, ...)`` returns the given JEs in order.""" + session = MagicMock() + session.get.side_effect = jes + return patch("testgen.mcp.tools.profile_history.get_current_session", return_value=session) + # ---------------------------------------------------------------------- # Helpers # ---------------------------------------------------------------------- @@ -267,7 +282,8 @@ def test_compare_profiling_runs_auto_baseline( mock_iss.select_where.return_value = [] mock_iss_type.select_where.return_value = [] - result = compare_profiling_runs(str(target_run.job_execution_id)) + with _patch_session([_je(), _je()]): + result = compare_profiling_runs(str(target_run.job_execution_id)) assert "Profiling Run Comparison" in result assert "Target" in result and "Baseline" in result @@ -277,20 +293,22 @@ def test_compare_profiling_runs_auto_baseline( @patch("testgen.mcp.tools.profile_history.resolve_profiling_run") def test_compare_profiling_runs_rejects_non_completed_target(mock_resolve, db_session_mock): - target_run = _profiling_run(status="Running") + target_run = _profiling_run() mock_resolve.return_value = target_run - with pytest.raises(MCPUserError, match="Target run is in `Running` state"): - compare_profiling_runs(str(target_run.job_execution_id)) + with _patch_session([_je(status=JobStatus.RUNNING)]): + with pytest.raises(MCPUserError, match="Target run is in `Running` state"): + compare_profiling_runs(str(target_run.job_execution_id)) @patch("testgen.mcp.tools.profile_history.resolve_profiling_run") -def test_compare_profiling_runs_rejects_cancelled_target(mock_resolve, db_session_mock): - target_run = _profiling_run(status="Cancelled") +def test_compare_profiling_runs_rejects_canceled_target(mock_resolve, db_session_mock): + target_run = _profiling_run() mock_resolve.return_value = target_run - with pytest.raises(MCPUserError, match="`Cancelled`"): - compare_profiling_runs(str(target_run.job_execution_id)) + with _patch_session([_je(status=JobStatus.CANCELED)]): + with pytest.raises(MCPUserError, match="`Canceled`"): + compare_profiling_runs(str(target_run.job_execution_id)) @patch("testgen.mcp.tools.profile_history.resolve_profiling_run") @@ -299,11 +317,17 @@ def test_compare_profiling_runs_rejects_cross_table_group(mock_resolve, db_sessi baseline_run = _profiling_run(table_groups_id=uuid4()) mock_resolve.side_effect = [target_run, baseline_run] - with pytest.raises(MCPUserError, match="same table group"): - compare_profiling_runs( - str(target_run.job_execution_id), - str(baseline_run.job_execution_id), - ) + with _patch_session([_je()]): + with pytest.raises(MCPUserError, match="same table group"): + compare_profiling_runs( + str(target_run.job_execution_id), + str(baseline_run.job_execution_id), + ) + + +def test_compare_profiling_runs_column_requires_table(db_session_mock): + with pytest.raises(MCPUserError, match="`column_name` requires `table_name`"): + compare_profiling_runs(str(uuid4()), column_name="email") @patch("testgen.mcp.tools.profile_history.resolve_profiling_run") @@ -312,8 +336,9 @@ def test_compare_profiling_runs_auto_baseline_first_run(mock_resolve, db_session target_run.get_previous.return_value = None mock_resolve.return_value = target_run - with pytest.raises(MCPUserError, match="first completed profiling run"): - compare_profiling_runs(str(target_run.job_execution_id)) + with _patch_session([_je()]): + with pytest.raises(MCPUserError, match="first completed profiling run"): + compare_profiling_runs(str(target_run.job_execution_id)) @patch("testgen.mcp.tools.profile_history.HygieneIssue") @@ -335,7 +360,8 @@ def test_compare_profiling_runs_identical_runs_renders_no_changes( mock_iss.select_where.return_value = [] mock_iss_type.select_where.return_value = [] - result = compare_profiling_runs(str(target_run.job_execution_id)) + with _patch_session([_je(), _je()]): + result = compare_profiling_runs(str(target_run.job_execution_id)) assert "No changes between target and baseline" in result From 354aa95d76bd7107827778bb0c031c05885ad72f Mon Sep 17 00:00:00 2001 From: Aarthy Adityan Date: Thu, 2 Apr 2026 18:51:30 -0400 Subject: [PATCH 26/58] feat(salesforce): add Salesforce Data 360 flavor Implement Salesforce Data 360 (Data Cloud) as a profiling and testing target. Data 360 speaks HTTP+OAuth and exposes the Tableau Hyper engine through an API, so the flavor wraps salesforce-cdp-connector via a custom SQLAlchemy dialect rather than a wire-protocol driver. - Flavor service + dialect that surface client-credentials and JWT-bearer authentication, with the table-group form labelling the schema field as "Data Space" - Schema discovery via the metadata API since Data 360 has no information_schema - 51 test-type and 32 hygiene-issue SQL templates ported to the Hyper dialect - Default max_query_chars lowered to 15000 for this flavor to stay within Hyper's expression-depth limit on aggregated CAT runs - Unit tests for the flavor service and refresh-data-chars path TG-994 Co-Authored-By: Claude Opus 4.7 (1M context) --- pyproject.toml | 4 + .../commands/queries/execute_tests_query.py | 3 +- testgen/commands/queries/profiling_query.py | 2 +- .../queries/refresh_data_chars_query.py | 54 ++- testgen/commands/run_profiling.py | 2 +- testgen/commands/run_refresh_data_chars.py | 21 +- testgen/commands/run_test_validation.py | 59 ++- testgen/common/database/column_chars.py | 17 + .../common/database/flavor/flavor_service.py | 23 +- .../salesforce_data360_flavor_service.py | 128 ++++++ .../database/salesforce_data360_dialect.py | 169 +++++++ testgen/common/models/connection.py | 3 +- testgen/common/source_data_service.py | 6 +- .../030_initialize_new_schema_structure.sql | 2 +- ..._anomaly_types_Boolean_Value_Mismatch.yaml | 8 + ...anomaly_types_Char_Column_Date_Values.yaml | 8 + ...omaly_types_Char_Column_Number_Values.yaml | 8 + ...anomaly_types_Column_Pattern_Mismatch.yaml | 10 +- ...anomaly_types_Delimited_Data_Embedded.yaml | 8 + ...ile_anomaly_types_Inconsistent_Casing.yaml | 14 + ...rofile_anomaly_types_Invalid_Zip3_USA.yaml | 8 + ...profile_anomaly_types_Invalid_Zip_USA.yaml | 8 + .../profile_anomaly_types_Leading_Spaces.yaml | 8 + ...le_anomaly_types_Multiple_Types_Major.yaml | 7 + ...le_anomaly_types_Multiple_Types_Minor.yaml | 7 + .../profile_anomaly_types_No_Values.yaml | 8 + ..._anomaly_types_Non_Alpha_Name_Address.yaml | 10 + ...anomaly_types_Non_Alpha_Prefixed_Name.yaml | 10 + ...file_anomaly_types_Non_Printing_Chars.yaml | 31 ++ ...ile_anomaly_types_Non_Standard_Blanks.yaml | 8 + ...le_anomaly_types_Potential_Duplicates.yaml | 8 + .../profile_anomaly_types_Potential_PII.yaml | 8 + .../profile_anomaly_types_Quoted_Values.yaml | 8 + ...rofile_anomaly_types_Recency_One_Year.yaml | 7 + ...file_anomaly_types_Recency_Six_Months.yaml | 7 + ...nomaly_types_Small_Divergent_Value_Ct.yaml | 8 + ..._anomaly_types_Small_Missing_Value_Ct.yaml | 8 + ..._anomaly_types_Small_Numeric_Value_Ct.yaml | 8 + ...maly_types_Standardized_Value_Matches.yaml | 8 + .../profile_anomaly_types_Suggested_Type.yaml | 8 + ..._anomaly_types_Table_Pattern_Mismatch.yaml | 7 + ...ofile_anomaly_types_Unexpected_Emails.yaml | 8 + ...le_anomaly_types_Unexpected_US_States.yaml | 8 + ...le_anomaly_types_Unlikely_Date_Values.yaml | 8 + ...le_anomaly_types_Variant_Coded_Values.yaml | 8 + .../test_types_Aggregate_Balance.yaml | 74 ++++ .../test_types_Aggregate_Balance_Percent.yaml | 76 ++++ .../test_types_Aggregate_Balance_Range.yaml | 76 ++++ .../test_types_Aggregate_Minimum.yaml | 74 ++++ .../test_types_Alpha_Trunc.yaml | 16 + .../test_types_Avg_Shift.yaml | 16 + .../dbsetup_test_types/test_types_CUSTOM.yaml | 39 ++ .../test_types_Combo_Match.yaml | 66 +++ .../test_types_Condition_Flag.yaml | 16 + .../test_types_Constant.yaml | 16 + .../test_types_Daily_Record_Ct.yaml | 16 + .../test_types_Dec_Trunc.yaml | 16 + .../test_types_Distinct_Date_Ct.yaml | 16 + .../test_types_Distinct_Value_Ct.yaml | 16 + .../test_types_Distribution_Shift.yaml | 114 ++++- .../test_types_Dupe_Rows.yaml | 52 +++ .../test_types_Email_Format.yaml | 16 + .../test_types_Freshness_Trend.yaml | 50 +++ .../test_types_Future_Date.yaml | 16 + .../test_types_Future_Date_1Y.yaml | 16 + .../test_types_Incr_Avg_Shift.yaml | 16 + .../test_types_LOV_All.yaml | 16 + .../test_types_LOV_Match.yaml | 16 + .../test_types_Metric_Trend.yaml | 19 + .../test_types_Min_Date.yaml | 16 + .../test_types_Min_Val.yaml | 16 + .../test_types_Missing_Pct.yaml | 16 + .../test_types_Monthly_Rec_Ct.yaml | 16 + .../test_types_Outlier_Pct_Above.yaml | 16 + .../test_types_Outlier_Pct_Below.yaml | 16 + .../test_types_Pattern_Match.yaml | 16 + .../test_types_Recency.yaml | 16 + .../test_types_Required.yaml | 16 + .../dbsetup_test_types/test_types_Row_Ct.yaml | 15 + .../test_types_Row_Ct_Pct.yaml | 16 + .../test_types_Schema_Drift.yaml | 55 +++ .../test_types_Street_Addr_Pattern.yaml | 16 + .../test_types_Table_Freshness.yaml | 32 ++ .../test_types_Timeframe_Combo_Gain.yaml | 66 +++ .../test_types_Timeframe_Combo_Match.yaml | 93 ++++ .../test_types_US_State.yaml | 16 + .../dbsetup_test_types/test_types_Unique.yaml | 16 + .../test_types_Unique_Pct.yaml | 16 + .../test_types_Valid_Characters.yaml | 16 + .../test_types_Valid_Month.yaml | 8 + .../test_types_Valid_US_Zip.yaml | 16 + .../test_types_Valid_US_Zip3.yaml | 16 + .../test_types_Variability_Decrease.yaml | 16 + .../test_types_Variability_Increase.yaml | 16 + .../test_types_Volume_Trend.yaml | 19 + .../test_types_Weekly_Rec_Ct.yaml | 16 + .../dbupgrade/0190_incremental_upgrade.sql | 4 + .../gen_query_tests/gen_Dupe_Rows.sql | 55 +++ .../gen_query_tests/gen_Freshness_Trend.sql | 210 +++++++++ .../gen_query_tests/gen_Table_Freshness.sql | 189 ++++++++ .../profiling/project_profiling_query.sql | 247 +++++++++++ .../project_secondary_profiling_query.sql | 37 ++ .../profiling/templated_functions.yaml | 98 +++++ .../ui/assets/flavors/salesforce_data360.svg | 83 ++++ .../frontend/js/pages/table_group_list.js | 8 +- .../frontend/js/pages/test_definitions.js | 21 +- testgen/ui/queries/table_group_queries.py | 13 +- .../static/js/components/connection_form.js | 189 +++++++- .../static/js/components/table_group_form.js | 32 +- .../js/components/test_definition_form.js | 12 +- testgen/ui/views/connections.py | 6 + testgen/ui/views/data_catalog.py | 9 +- testgen/ui/views/test_definitions.py | 31 +- .../queries/test_refresh_data_chars_query.py | 90 ++++ .../common/test_salesforce_data360_flavor.py | 415 ++++++++++++++++++ 115 files changed, 3874 insertions(+), 118 deletions(-) create mode 100644 testgen/common/database/column_chars.py create mode 100644 testgen/common/database/flavor/salesforce_data360_flavor_service.py create mode 100644 testgen/common/database/salesforce_data360_dialect.py create mode 100644 testgen/template/dbupgrade/0190_incremental_upgrade.sql create mode 100644 testgen/template/flavors/salesforce_data360/gen_query_tests/gen_Dupe_Rows.sql create mode 100644 testgen/template/flavors/salesforce_data360/gen_query_tests/gen_Freshness_Trend.sql create mode 100644 testgen/template/flavors/salesforce_data360/gen_query_tests/gen_Table_Freshness.sql create mode 100644 testgen/template/flavors/salesforce_data360/profiling/project_profiling_query.sql create mode 100644 testgen/template/flavors/salesforce_data360/profiling/project_secondary_profiling_query.sql create mode 100644 testgen/template/flavors/salesforce_data360/profiling/templated_functions.yaml create mode 100644 testgen/ui/assets/flavors/salesforce_data360.svg create mode 100644 tests/unit/common/test_salesforce_data360_flavor.py diff --git a/pyproject.toml b/pyproject.toml index 3cae9aed..49262388 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dependencies = [ "oracledb==3.4.0", "hdbcli==2.25.31", "sqlalchemy-hana==4.4.0", + "salesforce-cdp-connector>=1.0.19", "pyodbc==5.2.0", "psycopg2-binary==2.9.11", "pycryptodome==3.21", @@ -117,6 +118,9 @@ release = [ testgen = "testgen.__main__:cli" tg-patch-streamlit = "testgen.ui.scripts.patch_streamlit:patch" +[project.entry-points."sqlalchemy.dialects"] +salesforce_data360 = "testgen.common.database.salesforce_data360_dialect:SalesforceData360Dialect" + [project.urls] "Source Code" = "https://github.com/DataKitchen/dataops-testgen" "Bug Tracker" = "https://github.com/DataKitchen/dataops-testgen/issues" diff --git a/testgen/commands/queries/execute_tests_query.py b/testgen/commands/queries/execute_tests_query.py index 794ffcdb..03eab489 100644 --- a/testgen/commands/queries/execute_tests_query.py +++ b/testgen/commands/queries/execute_tests_query.py @@ -510,7 +510,6 @@ def aggregate_cat_tests( ) -> tuple[list[tuple[str, None]], list[list[TestExecutionDef]]]: varchar_type = self.flavor_service.varchar_type concat_operator = self.flavor_service.concat_operator - quote = self.flavor_service.quote_character for td in test_defs: # Don't recalculate expressions if it was already done before @@ -545,7 +544,7 @@ def aggregate_cat_tests( f"SELECT {len(aggregate_queries)} AS query_index, " f"{concat_operator.join([td.measure_expression for td in group])} AS result_measures, " f"{concat_operator.join([td.condition_expression for td in group])} AS result_codes " - f"FROM {quote}{group[0].schema_name}{quote}.{quote}{group[0].table_name}{quote}" + f"FROM {self.flavor_service.get_table_ref(group[0].schema_name, group[0].table_name)}" ) query = query.replace(":", "\\:") diff --git a/testgen/commands/queries/profiling_query.py b/testgen/commands/queries/profiling_query.py index d3f02a16..f7a72aac 100644 --- a/testgen/commands/queries/profiling_query.py +++ b/testgen/commands/queries/profiling_query.py @@ -1,8 +1,8 @@ import dataclasses from uuid import UUID -from testgen.commands.queries.refresh_data_chars_query import ColumnChars from testgen.common import read_template_sql_file +from testgen.common.database.column_chars import ColumnChars from testgen.common.database.database_service import process_conditionals, replace_params from testgen.common.models.connection import Connection from testgen.common.models.profiling_run import ProfilingRun diff --git a/testgen/commands/queries/refresh_data_chars_query.py b/testgen/commands/queries/refresh_data_chars_query.py index b494c308..762e7261 100644 --- a/testgen/commands/queries/refresh_data_chars_query.py +++ b/testgen/commands/queries/refresh_data_chars_query.py @@ -1,28 +1,20 @@ -import dataclasses +import re from collections.abc import Iterable from datetime import datetime from testgen.common import read_template_sql_file +from testgen.common.database.column_chars import ColumnChars from testgen.common.database.database_service import get_flavor_service, replace_params from testgen.common.models.connection import Connection from testgen.common.models.table_group import TableGroup from testgen.utils import chunk_queries, to_sql_timestamp -@dataclasses.dataclass -class ColumnChars: - schema_name: str - table_name: str - column_name: str - ordinal_position: int = None - general_type: str = None - column_type: str = None - db_data_type: str = None - is_decimal: bool = False - approx_record_ct: int = None - # This should not default to 0 since we don't always retrieve actual row counts - # UI relies on the null value to know that the approx_record_ct should be displayed instead - record_ct: int = None +def _like_to_regex(pattern: str) -> re.Pattern[str]: + # Mirrors SQL LIKE semantics used in _get_table_criteria: `%` is the only + # wildcard; `_` is treated as a literal character (escaped to `\_` in the + # SQL path). Anything else is literal. + return re.compile("^" + re.escape(pattern.strip()).replace("%", ".*") + "$") class RefreshDataCharsSQL: @@ -100,6 +92,32 @@ def _get_table_criteria(self) -> str: return table_criteria + def filter_schema_columns(self, columns: list[ColumnChars]) -> list[ColumnChars]: + """Apply the table group's filters (table set, include/exclude masks) to a column list. + + Mirrors `_get_table_criteria` for flavors that bypass the SQL template path + (e.g., Salesforce Data 360, where columns come from the metadata API). + """ + result = columns + + if self.table_group.profiling_table_set: + allowed = {item.strip() for item in self.table_group.profiling_table_set.split(",")} + result = [c for c in result if c.table_name in allowed] + + if self.table_group.profiling_include_mask: + include_patterns = [ + _like_to_regex(item) for item in self.table_group.profiling_include_mask.split(",") + ] + result = [c for c in result if any(p.match(c.table_name) for p in include_patterns)] + + if self.table_group.profiling_exclude_mask: + exclude_patterns = [ + _like_to_regex(item) for item in self.table_group.profiling_exclude_mask.split(",") + ] + result = [c for c in result if not any(p.match(c.table_name) for p in exclude_patterns)] + + return result + def get_schema_ddf(self) -> tuple[str, dict]: # Runs on Target database return self._get_query( @@ -111,9 +129,8 @@ def get_schema_ddf(self) -> tuple[str, dict]: def get_row_counts(self, table_names: Iterable[str]) -> list[tuple[str, None]]: # Runs on Target database schema = self.table_group.table_group_schema - quote = self.flavor_service.quote_character count_queries = [ - f"SELECT '{table}' AS table_name, COUNT(*) AS row_count FROM {quote}{schema}{quote}.{quote}{table}{quote}" + f"SELECT '{table}' AS table_name, COUNT(*) AS row_count FROM {self.flavor_service.get_table_ref(schema, table)}" for table in table_names ] chunked_queries = chunk_queries(count_queries, " UNION ALL ", self.connection.max_query_chars) @@ -122,8 +139,7 @@ def get_row_counts(self, table_names: Iterable[str]) -> list[tuple[str, None]]: def verify_access(self, table_name: str) -> tuple[str, None]: # Runs on Target database schema = self.table_group.table_group_schema - quote = self.flavor_service.quote_character - table_ref = f"{quote}{schema}{quote}.{quote}{table_name}{quote}" + table_ref = self.flavor_service.get_table_ref(schema, table_name) prefix, suffix = self.flavor_service.row_limit_clauses(1) query = f"SELECT {prefix} 1 FROM {table_ref} {suffix}".strip() return (query, None) diff --git a/testgen/commands/run_profiling.py b/testgen/commands/run_profiling.py index 2125defc..ffd1a58f 100644 --- a/testgen/commands/run_profiling.py +++ b/testgen/commands/run_profiling.py @@ -9,7 +9,6 @@ TableSampling, calculate_sampling_params, ) -from testgen.commands.queries.refresh_data_chars_query import ColumnChars from testgen.commands.run_refresh_data_chars import run_data_chars_refresh from testgen.commands.test_generation import run_monitor_generation, run_test_generation from testgen.common import ( @@ -19,6 +18,7 @@ set_target_db_params, write_to_app_db, ) +from testgen.common.database.column_chars import ColumnChars from testgen.common.database.database_service import ThreadedProgress from testgen.common.job_context import job_context from testgen.common.mixpanel_service import MixpanelService diff --git a/testgen/commands/run_refresh_data_chars.py b/testgen/commands/run_refresh_data_chars.py index 94f9b3e0..c4b483ce 100644 --- a/testgen/commands/run_refresh_data_chars.py +++ b/testgen/commands/run_refresh_data_chars.py @@ -1,13 +1,15 @@ import logging from datetime import datetime -from testgen.commands.queries.refresh_data_chars_query import ColumnChars, RefreshDataCharsSQL +from testgen.commands.queries.refresh_data_chars_query import RefreshDataCharsSQL +from testgen.common.database.column_chars import ColumnChars from testgen.common.database.database_service import ( execute_db_queries, fetch_dict_from_db, fetch_from_db_threaded, write_to_app_db, ) +from testgen.common.database.flavor.flavor_service import resolve_connection_params from testgen.common.models.connection import Connection from testgen.common.models.table_group import TableGroup from testgen.utils import get_exception_message @@ -20,16 +22,25 @@ def run_data_chars_refresh(connection: Connection, table_group: TableGroup, run_ LOG.info("Getting DDF for table group") try: - data_chars = fetch_dict_from_db(*sql_generator.get_schema_ddf(), use_target_db=True) + if sql_generator.flavor_service.metadata_via_api: + # Flavor returns column metadata directly (e.g., Salesforce Data 360 + # via the connector's metadata API). These flavors have no information_schema. + # Apply the table-group filters in Python + # since we bypass the SQL {TABLE_CRITERIA} clause. + params = resolve_connection_params(connection.__dict__) + api_columns = sql_generator.flavor_service.get_schema_columns(params, table_group.table_group_schema) or [] + data_chars = sql_generator.filter_schema_columns(api_columns) + else: + rows = fetch_dict_from_db(*sql_generator.get_schema_ddf(), use_target_db=True) + data_chars = [ColumnChars(**row) for row in rows] except Exception as e: raise RuntimeError(f"Error refreshing columns for data catalog. {get_exception_message(e)}") from e - - data_chars = [ColumnChars(**column) for column in data_chars] + if data_chars: distinct_tables = {column.table_name for column in data_chars} LOG.info(f"Tables: {len(distinct_tables)}, Columns: {len(data_chars)}") count_queries = sql_generator.get_row_counts(distinct_tables) - + LOG.info("Getting row counts for table group") count_results, _, error_data = fetch_from_db_threaded( count_queries, use_target_db=True, max_threads=connection.max_threads, diff --git a/testgen/commands/run_test_validation.py b/testgen/commands/run_test_validation.py index db247676..2c1a89a0 100644 --- a/testgen/commands/run_test_validation.py +++ b/testgen/commands/run_test_validation.py @@ -4,7 +4,9 @@ from testgen.commands.queries.execute_tests_query import TestExecutionDef, TestExecutionSQL from testgen.common import execute_db_queries, fetch_dict_from_db +from testgen.common.database.column_chars import ColumnChars from testgen.common.database.database_service import write_to_app_db +from testgen.common.database.flavor.flavor_service import resolve_connection_params LOG = logging.getLogger("testgen") @@ -79,6 +81,47 @@ def add_error(test_id: UUID, error: str) -> None: return identifiers_to_check, target_schemas, errors +def get_target_identifiers( + sql_generator: TestExecutionSQL, + target_schemas: set[str], +) -> tuple[set[tuple[str, str]], set[tuple[str, str, str]]]: + """Fetch (schema, table) and (schema, table, column) sets for validation. + + Flavors with ``metadata_via_api=True`` (e.g., Salesforce Data 360) + use ``get_schema_columns()`` — these flavors have no ``information_schema``. + Other flavors use the SQL template path. + """ + flavor_service = sql_generator.flavor_service + + if flavor_service.metadata_via_api: + params = resolve_connection_params(sql_generator.connection.__dict__) + api_columns: list[ColumnChars] = [] + for schema in target_schemas: + cols = flavor_service.get_schema_columns(params, schema) or [] + api_columns.extend(cols) + LOG.info("Got tables and columns from flavor metadata API for validation") + target_tables = {(c.schema_name.lower(), c.table_name.lower()) for c in api_columns} + target_columns = { + (c.schema_name.lower(), c.table_name.lower(), c.column_name.lower()) for c in api_columns + } + return target_tables, target_columns + + LOG.info("Getting tables and columns in target schemas for validation") + target_identifiers = fetch_dict_from_db( + *sql_generator.get_target_identifiers(target_schemas), + use_target_db=True, + ) + if not target_identifiers: + LOG.info("No tables or columns present in target schemas") + + target_tables = {(item["schema_name"].lower(), item["table_name"].lower()) for item in target_identifiers} + target_columns = { + (item["schema_name"].lower(), item["table_name"].lower(), item["column_name"].lower()) + for item in target_identifiers + } + return target_tables, target_columns + + def check_identifiers( identifiers_to_check: dict[tuple[str, str, str | None], set[UUID]], target_tables: set[tuple[str, str]], @@ -130,21 +173,7 @@ def run_test_validation( test_defs_by_id[test_id].errors = error_list if target_schemas: - LOG.info("Getting tables and columns in target schemas for validation") - target_identifiers = fetch_dict_from_db( - *sql_generator.get_target_identifiers(target_schemas), - use_target_db=True, - ) - if not target_identifiers: - LOG.info("No tables or columns present in target schemas") - - # Normalize identifiers before validating - target_tables = {(item["schema_name"].lower(), item["table_name"].lower()) for item in target_identifiers} - target_columns = { - (item["schema_name"].lower(), item["table_name"].lower(), item["column_name"].lower()) - for item in target_identifiers - } - + target_tables, target_columns = get_target_identifiers(sql_generator, target_schemas) check_errors = check_identifiers(identifiers_to_check, target_tables, target_columns) for test_id, error_list in check_errors.items(): if not test_defs_by_id[test_id].errors: diff --git a/testgen/common/database/column_chars.py b/testgen/common/database/column_chars.py new file mode 100644 index 00000000..6faa08f7 --- /dev/null +++ b/testgen/common/database/column_chars.py @@ -0,0 +1,17 @@ +import dataclasses + + +@dataclasses.dataclass +class ColumnChars: + schema_name: str + table_name: str + column_name: str + ordinal_position: int | None = None + general_type: str | None = None + column_type: str | None = None + db_data_type: str | None = None + is_decimal: bool = False + approx_record_ct: int | None = None + # This should not default to 0 since we don't always retrieve actual row counts + # UI relies on the null value to know that the approx_record_ct should be displayed instead + record_ct: int | None = None diff --git a/testgen/common/database/flavor/flavor_service.py b/testgen/common/database/flavor/flavor_service.py index 0f4f576b..becb8b38 100644 --- a/testgen/common/database/flavor/flavor_service.py +++ b/testgen/common/database/flavor/flavor_service.py @@ -6,9 +6,10 @@ from sqlalchemy import create_engine as sqlalchemy_create_engine from sqlalchemy.engine.base import Engine +from testgen.common.database.column_chars import ColumnChars from testgen.common.encrypt import DecryptText -SQLFlavor = Literal["redshift", "redshift_spectrum", "snowflake", "mssql", "postgresql", "databricks", "bigquery", "oracle", "sap_hana"] +SQLFlavor = Literal["redshift", "redshift_spectrum", "snowflake", "mssql", "postgresql", "databricks", "bigquery", "oracle", "sap_hana", "salesforce_data360"] RowLimitingClause = Literal["limit", "top", "fetch"] @@ -108,10 +109,29 @@ def row_limit_clauses(self, n: int) -> tuple[str, str]: if self.row_limiting_clause == "fetch": return "", f"FETCH FIRST {n} ROWS ONLY" return "", f"LIMIT {n}" + default_uppercase = False test_query = "SELECT 1" url_scheme = "postgresql" + qualifies_table_refs_with_schema = True + metadata_via_api = False + + def get_schema_columns(self, _params: ResolvedConnectionParams, _schema: str) -> list[ColumnChars] | None: + """Return column metadata without querying information_schema. + + Override this for flavors that lack information_schema and set ``metadata_via_api = True``. + Return None to use the standard SQL template path. + """ + return None + + def get_table_ref(self, schema: str, table: str) -> str: + """Return a fully-qualified table reference for SQL queries.""" + q = self.quote_character + if not self.qualifies_table_refs_with_schema: + return f"{q}{table}{q}" + return f"{q}{schema}{q}.{q}{table}{q}" + def get_pre_connection_queries(self, params: ResolvedConnectionParams) -> list[tuple[str, dict | None]]: # noqa: ARG002 return [] @@ -142,4 +162,3 @@ def get_connection_string_from_fields(self, params: ResolvedConnectionParams) -> def get_connection_string_head(self, params: ResolvedConnectionParams) -> str: return f"{self.url_scheme}://{params.username}:{quote_plus(params.password)}@" - diff --git a/testgen/common/database/flavor/salesforce_data360_flavor_service.py b/testgen/common/database/flavor/salesforce_data360_flavor_service.py new file mode 100644 index 00000000..c5d27d66 --- /dev/null +++ b/testgen/common/database/flavor/salesforce_data360_flavor_service.py @@ -0,0 +1,128 @@ +from typing import Any + +from sqlalchemy.dialects import registry +from sqlalchemy.pool import StaticPool + +from testgen.common.database.column_chars import ColumnChars +from testgen.common.database.flavor.flavor_service import FlavorService, ResolvedConnectionParams + +# Register the dialect so create_engine("salesforce_data360://") works +# without requiring an installed entry point. +registry.register("salesforce_data360", "testgen.common.database.salesforce_data360_dialect", "SalesforceData360Dialect") + +# Mapping from Data 360 metadata types to TestGen general_type codes. +# Data 360's metadata API returns a small fixed vocabulary — these 6 are all that +# have been observed against profiled DMOs and DLOs. Unknown types preserve the +# raw metadata string as column_type and fall through to general_type "X" in +# get_schema_columns(), matching get_schema_ddf.sql behavior for other flavors. +_TYPE_MAP: dict[str, tuple[str, str, bool]] = { + # metadata_type → (column_type, general_type, is_decimal) + "STRING": ("varchar", "A", False), + "NUMBER": ("numeric", "N", True), + "BIGINT": ("bigint", "N", False), + "BOOLEAN": ("boolean", "B", False), + "DATE": ("date", "D", False), + "DATE_TIME": ("datetime", "D", False), +} + + +class SalesforceData360FlavorService(FlavorService): + + concat_operator = "||" + quote_character = '"' + escaped_single_quote = "''" + escaped_underscore = "\\_" + escape_clause = "" + varchar_type = "VARCHAR(1000)" + default_uppercase = False + test_query = "SELECT 1" + url_scheme = "salesforce_data360" + qualifies_table_refs_with_schema = False + metadata_via_api = True + + def get_connection_string(self, _params: ResolvedConnectionParams) -> str: + return "salesforce_data360://" + + def get_connection_string_from_fields(self, _params: ResolvedConnectionParams) -> str: + return "salesforce_data360://" + + def get_connect_args(self, params: ResolvedConnectionParams) -> dict: + # Map Connection model fields to salesforce-cdp-connector kwargs. + # project_host → login_url (org My Domain URL) + # project_user → client_id (Consumer Key from External Client App) + # password → client_secret (Client Credentials flow) + # project_db → username (JWT Bearer flow) + # private_key → private_key (JWT Bearer flow) + # connect_by_key → True = JWT, False = Client Credentials + # table_group_schema → dataspace (Data 360 Data Space — scopes the CDP token) + args: dict[str, Any] = { + "login_url": params.host, + "client_id": params.username, + } + + # Connection-only contexts (Test Connection from the connection wizard) have + # no table group yet, so dbschema is empty — the connector then defaults to + # the org's default Data Space, which is fine for "can we authenticate?". + # Table-group-scoped contexts (profiling, test execution, preview) supply + # the Data Space and the resulting CDP token is restricted to it. + if params.dbschema: + args["dataspace"] = params.dbschema + + if params.connect_by_key and params.private_key: + args["username"] = params.dbname + args["private_key"] = params.private_key + else: + args["client_secret"] = params.password + + return args + + def get_engine_args(self, _params: ResolvedConnectionParams) -> dict[str, Any]: + return { + "pool_pre_ping": False, + "poolclass": StaticPool, + } + + def get_pre_connection_queries(self, _params: ResolvedConnectionParams) -> list[tuple[str, dict | None]]: + return [] + + def get_schema_columns(self, params: ResolvedConnectionParams, schema: str) -> list[ColumnChars] | None: + """Fetch column metadata via the salesforce-cdp-connector metadata API. + + Data 360 has no information_schema — this method replaces the SQL-based + schema discovery for this flavor. + """ + from salesforcecdpconnector.connection import SalesforceCDPConnection + + connect_args = self.get_connect_args(params) + conn = SalesforceCDPConnection(**connect_args) + + try: + tables = conn.list_tables() + finally: + conn.close() + + columns: list[ColumnChars] = [] + for table in tables: + for ordinal, field in enumerate(table.fields, start=1): + if not field.name: + continue + + meta_type = (field.type or "").upper() + mapped = _TYPE_MAP.get(meta_type) + if mapped is not None: + column_type, general_type, is_decimal = mapped + else: + column_type, general_type, is_decimal = meta_type.lower(), "X", False + + columns.append(ColumnChars( + schema_name=schema, + table_name=table.name, + column_name=field.name, + column_type=column_type, + db_data_type=meta_type, + ordinal_position=ordinal, + general_type=general_type, + is_decimal=is_decimal, + )) + + return columns diff --git a/testgen/common/database/salesforce_data360_dialect.py b/testgen/common/database/salesforce_data360_dialect.py new file mode 100644 index 00000000..43b44459 --- /dev/null +++ b/testgen/common/database/salesforce_data360_dialect.py @@ -0,0 +1,169 @@ +"""Minimal SQLAlchemy dialect for Salesforce Data 360. + +Wraps the ``salesforce-cdp-connector`` DB-API 2.0 module so that +SQLAlchemy's ``create_engine`` / ``engine.connect()`` flow works. + +The connector speaks PostgreSQL-compatible SQL (Tableau Hyper engine) +but uses HTTP + OAuth instead of a wire protocol, so we inherit from +``DefaultDialect`` rather than ``PGDialect`` to avoid unwanted +introspection queries. +""" + +import time + +import jwt +from salesforcecdpconnector import authentication_helper as _auth_helper +from salesforcecdpconnector.constants import ( + AUTH_PARAM_ASSERTION, + AUTH_PARAM_CLIENT_CREDENTIALS_GRANT_TYPE, + AUTH_PARAM_CLIENT_ID, + AUTH_PARAM_CLIENT_SECRET, + AUTH_PARAM_GRANT_TYPE, + AUTH_PARAM_JWT_GRANT_TYPE, + AUTH_RESPONSE_ACCESS_TOKEN, + AUTH_RESPONSE_INSTANCE_URL, +) +from salesforcecdpconnector.exceptions import Error as _CdpError +from sqlalchemy.engine.default import DefaultDialect +from sqlalchemy.pool import StaticPool + + +def _format_oauth_failure(grant_label: str, response) -> str: + """Extract Salesforce's ``error`` / ``error_description`` from an OAuth failure. + + The stock connector discards the response body and surfaces only the HTTP + status, which leaves users without an actionable signal (e.g. ``user + hasn't approved this consumer`` vs ``invalid assertion`` vs ``invalid + grant``). This pulls the body fields out so the error reaches the UI. + """ + detail = "" + try: + body = response.json() + description = body.get("error_description") + code = body.get("error") + if description and code: + detail = f": {code} — {description}" + elif description: + detail = f": {description}" + elif code: + detail = f": {code}" + else: + detail = f": {response.text[:300]}" + except ValueError: + if response.text: + detail = f": {response.text[:300]}" + return f"Salesforce {grant_label} authentication failed (HTTP {response.status_code}){detail}" + + +def _token_by_jwt_bearer_flow(self, login_url, username, client_id, private_key): + payload = { + "iss": client_id, + "exp": int(time.time()) + 3600, + "aud": login_url, + "sub": username, + } + encoded = jwt.encode(payload, private_key, algorithm="RS256") + params = {AUTH_PARAM_GRANT_TYPE: AUTH_PARAM_JWT_GRANT_TYPE, AUTH_PARAM_ASSERTION: encoded} + response = self.session.post(url=login_url + "/services/oauth2/token", params=params) + if response.status_code == 200: + access_code = response.json() + return self._exchange_token(access_code[AUTH_RESPONSE_INSTANCE_URL], access_code[AUTH_RESPONSE_ACCESS_TOKEN]) + raise _CdpError(_format_oauth_failure("JWT Bearer", response)) + + +def _token_by_client_creds_flow(self, login_url, client_id, client_secret): + params = { + AUTH_PARAM_GRANT_TYPE: AUTH_PARAM_CLIENT_CREDENTIALS_GRANT_TYPE, + AUTH_PARAM_CLIENT_ID: client_id, + AUTH_PARAM_CLIENT_SECRET: client_secret, + } + response = self.session.post(url=login_url + "/services/oauth2/token", params=params) + if response.status_code == 200: + access_code = response.json() + return self._exchange_token(access_code[AUTH_RESPONSE_INSTANCE_URL], access_code[AUTH_RESPONSE_ACCESS_TOKEN]) + raise _CdpError(_format_oauth_failure("Client Credentials", response)) + + +# Replace the connector's auth methods at import time. The stock methods build +# the same request but throw away the response body on failure. The patched +# methods preserve SF's ``error_description`` so the cause is visible in the +# Test Connection UI and in application logs. +_auth_helper.AuthenticationHelper._token_by_jwt_bearer_flow = _token_by_jwt_bearer_flow +_auth_helper.AuthenticationHelper._token_by_client_creds_flow = _token_by_client_creds_flow + + +class _DBAPIShim: + """Shim module that satisfies SQLAlchemy's ``dialect.dbapi()`` contract. + + SQLAlchemy expects ``dbapi.connect(**kwargs)`` to return a DB-API + connection. We delegate to ``SalesforceCDPConnection``. + """ + + # Re-export the connector's exception hierarchy so SQLAlchemy can + # catch errors through the standard ``dbapi.Error`` path. + from salesforcecdpconnector.exceptions import ( + DatabaseError, + Error, + InterfaceError, + InternalError, + NotSupportedError, + OperationalError, + ProgrammingError, + ) + + paramstyle = "format" # SQLAlchemy needs *some* value; we never actually bind params + + @staticmethod + def connect(**kwargs): + from salesforcecdpconnector.connection import SalesforceCDPConnection + + conn = SalesforceCDPConnection(**kwargs) + # Patch the cursor factory to add missing DB-API attributes + _original_cursor = conn.cursor + + def _patched_cursor(): + cursor = _original_cursor() + if not hasattr(cursor, "rowcount"): + cursor.rowcount = -1 + if not hasattr(cursor, "lastrowid"): + cursor.lastrowid = None + return cursor + + conn.cursor = _patched_cursor + return conn + + +class SalesforceData360Dialect(DefaultDialect): + name = "salesforce_data360" + supports_alter = False + supports_transactions = False + supports_native_boolean = True + supports_statement_cache = False + supports_default_values = False + supports_empty_insert = False + postfetch_lastrowid = False + implicit_returning = False + + @classmethod + def dbapi(cls): + return _DBAPIShim + + @classmethod + def import_dbapi(cls): + return _DBAPIShim + + def create_connect_args(self, _url): + # All auth params arrive via connect_args; the URL is a dummy + # ``salesforce_data360://`` placeholder. + return ([], {}) + + def do_ping(self, _dbapi_connection): + return True + + def initialize(self, connection): + # Skip server-version detection and other introspection that + # DefaultDialect.initialize() performs. + pass + + def get_pool_class(self, _url): + return StaticPool diff --git a/testgen/common/models/connection.py b/testgen/common/models/connection.py index dfb36e71..e209299b 100644 --- a/testgen/common/models/connection.py +++ b/testgen/common/models/connection.py @@ -27,13 +27,14 @@ from testgen.common.models.table_group import TableGroup from testgen.utils import is_uuid4 -SQLFlavorCode = Literal["redshift", "redshift_spectrum", "snowflake", "mssql", "azure_mssql", "synapse_mssql", "postgresql", "databricks", "bigquery", "oracle", "sap_hana"] +SQLFlavorCode = Literal["redshift", "redshift_spectrum", "snowflake", "mssql", "azure_mssql", "synapse_mssql", "postgresql", "databricks", "bigquery", "oracle", "sap_hana", "salesforce_data360"] @dataclass class ConnectionMinimal(EntityMinimal): project_code: str connection_id: int + sql_flavor: SQLFlavor sql_flavor_code: SQLFlavorCode connection_name: str diff --git a/testgen/common/source_data_service.py b/testgen/common/source_data_service.py index c59d0b47..b807f119 100644 --- a/testgen/common/source_data_service.py +++ b/testgen/common/source_data_service.py @@ -265,13 +265,15 @@ def _generate_recency_lookup_query( column_names_str = detail_exp[start_index:] columns = [col.strip() for col in column_names_str.split(",")] - quote = get_flavor_service(sql_flavor).quote_character + flavor_service = get_flavor_service(sql_flavor) + quote = flavor_service.quote_character + table_ref = flavor_service.get_table_ref("{TARGET_SCHEMA}", "{TABLE_NAME}") queries = [ f""" SELECT '{column}' AS column_name, MAX({quote}{column}{quote}) AS max_date_available - FROM {quote}{{TARGET_SCHEMA}}{quote}.{quote}{{TABLE_NAME}}{quote} + FROM {table_ref} """ for column in columns ] diff --git a/testgen/template/dbsetup/030_initialize_new_schema_structure.sql b/testgen/template/dbsetup/030_initialize_new_schema_structure.sql index e77aa9c1..169704bb 100644 --- a/testgen/template/dbsetup/030_initialize_new_schema_structure.sql +++ b/testgen/template/dbsetup/030_initialize_new_schema_structure.sql @@ -72,7 +72,7 @@ CREATE TABLE connections ( sql_flavor_code VARCHAR(30), project_host VARCHAR(250), project_port VARCHAR(5), - project_user VARCHAR(50), + project_user VARCHAR(256), project_db VARCHAR(100), connection_name VARCHAR(40), project_pw_encrypted BYTEA, diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Boolean_Value_Mismatch.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Boolean_Value_Mismatch.yaml index c35be242..c6574673 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Boolean_Value_Mismatch.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Boolean_Value_Mismatch.yaml @@ -101,3 +101,11 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY COUNT(*) DESC LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10047' + test_id: 1015 + test_type: Boolean_Value_Mismatch + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY COUNT(*) DESC LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Char_Column_Date_Values.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Char_Column_Date_Values.yaml index a4a44110..8afd6b10 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Char_Column_Date_Values.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Char_Column_Date_Values.yaml @@ -112,3 +112,11 @@ profile_anomaly_types: lookup_query: |- SELECT A.* FROM (SELECT DISTINCT 'Date' as data_type, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE <%IS_DATE;"{COLUMN_NAME}"%> = 1 GROUP BY "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_2}) A UNION ALL SELECT B.* FROM (SELECT DISTINCT 'Non-Date' as data_type, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE <%IS_DATE;"{COLUMN_NAME}"%> != 1 GROUP BY "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_2}) B ORDER BY data_type, count DESC error_type: Profile Anomaly + - id: '10048' + test_id: 1012 + test_type: Char_Column_Date_Values + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT A.* FROM ( SELECT DISTINCT 'Date' as data_type, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE <%IS_DATE;"{COLUMN_NAME}"%> = 1 GROUP BY "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_2}) AS A UNION ALL SELECT B.* FROM ( SELECT DISTINCT 'Non-Date' as data_type, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE <%IS_DATE;"{COLUMN_NAME}"%> != 1 GROUP BY "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_2}) AS B ORDER BY data_type, count DESC; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Char_Column_Number_Values.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Char_Column_Number_Values.yaml index e23891b6..59504d2f 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Char_Column_Number_Values.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Char_Column_Number_Values.yaml @@ -112,3 +112,11 @@ profile_anomaly_types: lookup_query: |- SELECT A.* FROM (SELECT DISTINCT 'Numeric' as data_type, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE <%IS_NUM;"{COLUMN_NAME}"%> = 1 GROUP BY "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_2}) A UNION ALL SELECT B.* FROM (SELECT DISTINCT 'Non-Numeric' as data_type, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE <%IS_NUM;"{COLUMN_NAME}"%> != 1 GROUP BY "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_2}) B ORDER BY data_type, count DESC error_type: Profile Anomaly + - id: '10049' + test_id: 1011 + test_type: Char_Column_Number_Values + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT A.* FROM ( SELECT DISTINCT 'Numeric' as data_type, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE <%IS_NUM;"{COLUMN_NAME}"%> = 1 GROUP BY "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_2}) AS A UNION ALL SELECT B.* FROM ( SELECT DISTINCT 'Non-Numeric' as data_type, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE <%IS_NUM;"{COLUMN_NAME}"%> != 1 GROUP BY "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_2}) AS B ORDER BY data_type, count DESC; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Column_Pattern_Mismatch.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Column_Pattern_Mismatch.yaml index 87441e8a..2e228e69 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Column_Pattern_Mismatch.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Column_Pattern_Mismatch.yaml @@ -99,7 +99,7 @@ profile_anomaly_types: sql_flavor: postgresql lookup_type: null lookup_query: |- - SELECT A.* FROM ( SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}", (SELECT trim(split_part('{DETAIL_EXPRESSION}', '|', 4)) AS top_pattern) b WHERE REGEXP_REPLACE(REGEXP_REPLACE( REGEXP_REPLACE( "{COLUMN_NAME}", '[a-z]', 'a', 'g'), '[A-Z]', 'A', 'g'), '[0-9]', 'N', 'g') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) A UNION ALL SELECT B.* FROM ( SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}", (SELECT trim(split_part('{DETAIL_EXPRESSION}', '|', 6)) AS top_pattern) b WHERE REGEXP_REPLACE(REGEXP_REPLACE( REGEXP_REPLACE( "{COLUMN_NAME}", '[a-z]', 'a', 'g'), '[A-Z]', 'A', 'g'), '[0-9]', 'N', 'g') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) B UNION ALL SELECT C.* FROM ( SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}", (SELECT trim(split_part('{DETAIL_EXPRESSION}', '|', 8)) AS top_pattern) b WHERE REGEXP_REPLACE(REGEXP_REPLACE( REGEXP_REPLACE( "{COLUMN_NAME}", '[a-z]', 'a', 'g'), '[A-Z]', 'A', 'g'), '[0-9]', 'N', 'g') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) C UNION ALL SELECT D.* FROM ( SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}", (SELECT trim(split_part('{DETAIL_EXPRESSION}', '|', 10)) AS top_pattern) b WHERE REGEXP_REPLACE(REGEXP_REPLACE( REGEXP_REPLACE( "{COLUMN_NAME}", '[a-z]', 'a', 'g'), '[A-Z]', 'A', 'g'), '[0-9]', 'N', 'g') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) D ORDER BY top_pattern DESC, count DESC; + SELECT A.* FROM ( SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}", (SELECT trim(split_part('{DETAIL_EXPRESSION}', '|', 4)) AS top_pattern) b WHERE REGEXP_REPLACE(REGEXP_REPLACE( REGEXP_REPLACE( CAST("{COLUMN_NAME}" AS VARCHAR), '[a-z]', 'a', 'g'), '[A-Z]', 'A', 'g'), '[0-9]', 'N', 'g') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) A UNION ALL SELECT B.* FROM ( SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}", (SELECT trim(split_part('{DETAIL_EXPRESSION}', '|', 6)) AS top_pattern) b WHERE REGEXP_REPLACE(REGEXP_REPLACE( REGEXP_REPLACE( CAST("{COLUMN_NAME}" AS VARCHAR), '[a-z]', 'a', 'g'), '[A-Z]', 'A', 'g'), '[0-9]', 'N', 'g') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) B UNION ALL SELECT C.* FROM ( SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}", (SELECT trim(split_part('{DETAIL_EXPRESSION}', '|', 8)) AS top_pattern) b WHERE REGEXP_REPLACE(REGEXP_REPLACE( REGEXP_REPLACE( CAST("{COLUMN_NAME}" AS VARCHAR), '[a-z]', 'a', 'g'), '[A-Z]', 'A', 'g'), '[0-9]', 'N', 'g') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) C UNION ALL SELECT D.* FROM ( SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}", (SELECT trim(split_part('{DETAIL_EXPRESSION}', '|', 10)) AS top_pattern) b WHERE REGEXP_REPLACE(REGEXP_REPLACE( REGEXP_REPLACE( CAST("{COLUMN_NAME}" AS VARCHAR), '[a-z]', 'a', 'g'), '[A-Z]', 'A', 'g'), '[0-9]', 'N', 'g') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) D ORDER BY top_pattern DESC, count DESC; error_type: Profile Anomaly - id: '1039' test_id: '1007' @@ -141,3 +141,11 @@ profile_anomaly_types: lookup_query: |- SELECT A.* FROM (SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}", (SELECT TRIM(SUBSTR_REGEXPR('[^|]+' IN '{DETAIL_EXPRESSION}' OCCURRENCE 4)) AS top_pattern FROM DUMMY) b WHERE REPLACE_REGEXPR('[0-9]' IN REPLACE_REGEXPR('[A-Z]' IN REPLACE_REGEXPR('[a-z]' IN "{COLUMN_NAME}" WITH 'a') WITH 'A') WITH 'N') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) A UNION ALL SELECT B.* FROM (SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}", (SELECT TRIM(SUBSTR_REGEXPR('[^|]+' IN '{DETAIL_EXPRESSION}' OCCURRENCE 6)) AS top_pattern FROM DUMMY) b WHERE REPLACE_REGEXPR('[0-9]' IN REPLACE_REGEXPR('[A-Z]' IN REPLACE_REGEXPR('[a-z]' IN "{COLUMN_NAME}" WITH 'a') WITH 'A') WITH 'N') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) B UNION ALL SELECT C.* FROM (SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}", (SELECT TRIM(SUBSTR_REGEXPR('[^|]+' IN '{DETAIL_EXPRESSION}' OCCURRENCE 8)) AS top_pattern FROM DUMMY) b WHERE REPLACE_REGEXPR('[0-9]' IN REPLACE_REGEXPR('[A-Z]' IN REPLACE_REGEXPR('[a-z]' IN "{COLUMN_NAME}" WITH 'a') WITH 'A') WITH 'N') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) C UNION ALL SELECT D.* FROM (SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}", (SELECT TRIM(SUBSTR_REGEXPR('[^|]+' IN '{DETAIL_EXPRESSION}' OCCURRENCE 10)) AS top_pattern FROM DUMMY) b WHERE REPLACE_REGEXPR('[0-9]' IN REPLACE_REGEXPR('[A-Z]' IN REPLACE_REGEXPR('[a-z]' IN "{COLUMN_NAME}" WITH 'a') WITH 'A') WITH 'N') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) D ORDER BY top_pattern DESC, count DESC error_type: Profile Anomaly + - id: '10050' + test_id: 1007 + test_type: Column_Pattern_Mismatch + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT A.* FROM ( SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}", (SELECT trim(split_part('{DETAIL_EXPRESSION}', '|', 4)) AS top_pattern) b WHERE REGEXP_REPLACE(REGEXP_REPLACE( REGEXP_REPLACE( CAST("{COLUMN_NAME}" AS VARCHAR), '[a-z]', 'a', 'g'), '[A-Z]', 'A', 'g'), '[0-9]', 'N', 'g') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) A UNION ALL SELECT B.* FROM ( SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}", (SELECT trim(split_part('{DETAIL_EXPRESSION}', '|', 6)) AS top_pattern) b WHERE REGEXP_REPLACE(REGEXP_REPLACE( REGEXP_REPLACE( CAST("{COLUMN_NAME}" AS VARCHAR), '[a-z]', 'a', 'g'), '[A-Z]', 'A', 'g'), '[0-9]', 'N', 'g') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) B UNION ALL SELECT C.* FROM ( SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}", (SELECT trim(split_part('{DETAIL_EXPRESSION}', '|', 8)) AS top_pattern) b WHERE REGEXP_REPLACE(REGEXP_REPLACE( REGEXP_REPLACE( CAST("{COLUMN_NAME}" AS VARCHAR), '[a-z]', 'a', 'g'), '[A-Z]', 'A', 'g'), '[0-9]', 'N', 'g') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) C UNION ALL SELECT D.* FROM ( SELECT DISTINCT b.top_pattern, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}", (SELECT trim(split_part('{DETAIL_EXPRESSION}', '|', 10)) AS top_pattern) b WHERE REGEXP_REPLACE(REGEXP_REPLACE( REGEXP_REPLACE( CAST("{COLUMN_NAME}" AS VARCHAR), '[a-z]', 'a', 'g'), '[A-Z]', 'A', 'g'), '[0-9]', 'N', 'g') = b.top_pattern GROUP BY b.top_pattern, "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_4}) D ORDER BY top_pattern DESC, count DESC; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Delimited_Data_Embedded.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Delimited_Data_Embedded.yaml index 570ed5ad..ef4853e6 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Delimited_Data_Embedded.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Delimited_Data_Embedded.yaml @@ -96,3 +96,11 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE "{COLUMN_NAME}" LIKE_REGEXPR '^([^,|' || NCHAR(9) || ']{1,20}[,|' || NCHAR(9) || ']){2,}[^,|' || NCHAR(9) || ']{0,20}([,|' || NCHAR(9) || ']{0,1}[^,|' || NCHAR(9) || ']{0,20})*$' AND NOT "{COLUMN_NAME}" LIKE_REGEXPR '[[:space:]](and|but|or|yet)[[:space:]]' GROUP BY "{COLUMN_NAME}" ORDER BY COUNT(*) DESC LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10051' + test_id: 1025 + test_type: Delimited_Data_Embedded + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE REGEXP_LIKE(CAST("{COLUMN_NAME}" AS VARCHAR), '^([^,|\t]{1,20}[,|\t]){2,}[^,|\t]{0,20}([,|\t]{0,1}[^,|\t]{0,20})*$') AND NOT REGEXP_LIKE(CAST("{COLUMN_NAME}" AS VARCHAR), '\s(and|but|or|yet)\s') GROUP BY "{COLUMN_NAME}" ORDER BY COUNT(*) DESC LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Inconsistent_Casing.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Inconsistent_Casing.yaml index 176a3565..6c8d7156 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Inconsistent_Casing.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Inconsistent_Casing.yaml @@ -141,3 +141,17 @@ profile_anomaly_types: lookup_query: |- SELECT * FROM (SELECT 'Upper Case' as casing, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE UPPER("{COLUMN_NAME}") = "{COLUMN_NAME}" GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT_2}) UNION ALL SELECT * FROM (SELECT 'Mixed Case' as casing, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE "{COLUMN_NAME}" <> UPPER("{COLUMN_NAME}") AND "{COLUMN_NAME}" <> LOWER("{COLUMN_NAME}") GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT_2}) error_type: Profile Anomaly + - id: '10052' + test_id: 1028 + test_type: Inconsistent_Casing + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + (SELECT 'Upper Case' as casing, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" + WHERE UPPER("{COLUMN_NAME}") = "{COLUMN_NAME}" + GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT_2}) + UNION ALL + (SELECT 'Mixed Case' as casing, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" + WHERE "{COLUMN_NAME}" <> UPPER("{COLUMN_NAME}") AND "{COLUMN_NAME}" <> LOWER("{COLUMN_NAME}") + GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT_2}) + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Invalid_Zip3_USA.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Invalid_Zip3_USA.yaml index ed042ca9..b3c9a750 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Invalid_Zip3_USA.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Invalid_Zip3_USA.yaml @@ -98,3 +98,11 @@ profile_anomaly_types: lookup_query: |- SELECT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE REPLACE_REGEXPR('[0-9]' IN "{COLUMN_NAME}" WITH '9') <> '999' GROUP BY "{COLUMN_NAME}" ORDER BY count DESC, "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10053' + test_id: 1024 + test_type: Invalid_Zip3_USA + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE REGEXP_REPLACE(CAST("{COLUMN_NAME}" AS VARCHAR), '[0-8]', '9', 'g') <> '999' GROUP BY "{COLUMN_NAME}" ORDER BY count DESC, "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Invalid_Zip_USA.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Invalid_Zip_USA.yaml index 2e13f4a8..e9b094ec 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Invalid_Zip_USA.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Invalid_Zip_USA.yaml @@ -94,3 +94,11 @@ profile_anomaly_types: lookup_query: |- SELECT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE REPLACE_REGEXPR('[0-9]' IN "{COLUMN_NAME}" WITH '9') NOT IN ('99999', '999999999', '99999-9999') GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10054' + test_id: 1003 + test_type: Invalid_Zip_USA + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE REGEXP_REPLACE(CAST("{COLUMN_NAME}" AS VARCHAR), '[0-8]', '9', 'g') NOT IN ('99999', '999999999', '99999-9999') GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Leading_Spaces.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Leading_Spaces.yaml index d63a84e1..05ee4c0c 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Leading_Spaces.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Leading_Spaces.yaml @@ -94,3 +94,11 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE (CASE WHEN "{COLUMN_NAME}" BETWEEN ' !' AND '!' THEN 1 ELSE 0 END) = 1 GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10055' + test_id: 1009 + test_type: Leading_Spaces + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE (CASE WHEN "{COLUMN_NAME}" BETWEEN ' !' AND '!' THEN 1 ELSE 0 END) = 1 GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Multiple_Types_Major.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Multiple_Types_Major.yaml index cb4fa797..3633cc99 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Multiple_Types_Major.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Multiple_Types_Major.yaml @@ -108,3 +108,10 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT COLUMN_NAME, TABLE_NAME, CASE WHEN DATA_TYPE_NAME LIKE 'TIMESTAMP%%' THEN LOWER(DATA_TYPE_NAME) WHEN DATA_TYPE_NAME = 'DATE' THEN 'date' WHEN DATA_TYPE_NAME IN ('NVARCHAR', 'VARCHAR') THEN LOWER(DATA_TYPE_NAME) || '(' || LENGTH || ')' WHEN DATA_TYPE_NAME = 'CHAR' THEN 'char(' || LENGTH || ')' WHEN DATA_TYPE_NAME = 'DECIMAL' AND SCALE = 0 THEN 'decimal(' || LENGTH || ')' WHEN DATA_TYPE_NAME = 'DECIMAL' THEN 'decimal(' || LENGTH || ',' || SCALE || ')' WHEN DATA_TYPE_NAME IN ('INTEGER', 'BIGINT', 'SMALLINT', 'TINYINT') THEN LOWER(DATA_TYPE_NAME) ELSE LOWER(DATA_TYPE_NAME) END AS data_type FROM SYS.TABLE_COLUMNS WHERE SCHEMA_NAME = '{TARGET_SCHEMA}' AND COLUMN_NAME = '{COLUMN_NAME}' ORDER BY data_type, TABLE_NAME LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10056' + test_id: 1005 + test_type: Multiple_Types_Major + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: null + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Multiple_Types_Minor.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Multiple_Types_Minor.yaml index 80e9fa80..eb1555ad 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Multiple_Types_Minor.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Multiple_Types_Minor.yaml @@ -108,3 +108,10 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT COLUMN_NAME, TABLE_NAME, CASE WHEN DATA_TYPE_NAME LIKE 'TIMESTAMP%%' THEN LOWER(DATA_TYPE_NAME) WHEN DATA_TYPE_NAME = 'DATE' THEN 'date' WHEN DATA_TYPE_NAME IN ('NVARCHAR', 'VARCHAR') THEN LOWER(DATA_TYPE_NAME) || '(' || LENGTH || ')' WHEN DATA_TYPE_NAME = 'CHAR' THEN 'char(' || LENGTH || ')' WHEN DATA_TYPE_NAME = 'DECIMAL' AND SCALE = 0 THEN 'decimal(' || LENGTH || ')' WHEN DATA_TYPE_NAME = 'DECIMAL' THEN 'decimal(' || LENGTH || ',' || SCALE || ')' WHEN DATA_TYPE_NAME IN ('INTEGER', 'BIGINT', 'SMALLINT', 'TINYINT') THEN LOWER(DATA_TYPE_NAME) ELSE LOWER(DATA_TYPE_NAME) END AS data_type FROM SYS.TABLE_COLUMNS WHERE SCHEMA_NAME = '{TARGET_SCHEMA}' AND COLUMN_NAME = '{COLUMN_NAME}' ORDER BY data_type, TABLE_NAME LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10057' + test_id: 1004 + test_type: Multiple_Types_Minor + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: null + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_No_Values.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_No_Values.yaml index 46c6f955..7f3835b4 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_No_Values.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_No_Values.yaml @@ -96,3 +96,11 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10058' + test_id: 1006 + test_type: No_Values + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Alpha_Name_Address.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Alpha_Name_Address.yaml index 820e6423..29d1aff7 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Alpha_Name_Address.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Alpha_Name_Address.yaml @@ -108,3 +108,13 @@ profile_anomaly_types: lookup_query: |- SELECT "{COLUMN_NAME}", COUNT(*) as record_ct FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE "{COLUMN_NAME}" = UPPER("{COLUMN_NAME}") AND "{COLUMN_NAME}" = LOWER("{COLUMN_NAME}") AND "{COLUMN_NAME}" IS NOT NULL GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10059' + test_id: 1029 + test_type: Non_Alpha_Name_Address + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT "{COLUMN_NAME}", COUNT(*) as record_ct FROM "{TABLE_NAME}" + WHERE "{COLUMN_NAME}" = UPPER("{COLUMN_NAME}") AND "{COLUMN_NAME}" = LOWER("{COLUMN_NAME}") AND "{COLUMN_NAME}" > '' + GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT} + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Alpha_Prefixed_Name.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Alpha_Prefixed_Name.yaml index 22ed1cd9..78c96148 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Alpha_Prefixed_Name.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Alpha_Prefixed_Name.yaml @@ -110,3 +110,13 @@ profile_anomaly_types: lookup_query: |- SELECT "{COLUMN_NAME}", COUNT(*) as record_ct FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE "{COLUMN_NAME}" < 'A' AND SUBSTR("{COLUMN_NAME}", 1, 1) NOT IN ('"', ' ') AND SUBSTR("{COLUMN_NAME}", -1, 1) <> '''' GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10060' + test_id: 1030 + test_type: Non_Alpha_Prefixed_Name + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT "{COLUMN_NAME}", COUNT(*) as record_ct FROM "{TABLE_NAME}" + WHERE "{COLUMN_NAME}" < 'A' AND SUBSTR("{COLUMN_NAME}", 1, 1) NOT IN ('"', ' ') AND SUBSTRING("{COLUMN_NAME}", LENGTH("{COLUMN_NAME}")) <> '''' + GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT} + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Printing_Chars.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Printing_Chars.yaml index 34821875..e8291e15 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Printing_Chars.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Printing_Chars.yaml @@ -161,3 +161,34 @@ profile_anomaly_types: lookup_query: |- SELECT REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE("{COLUMN_NAME}", NCHAR(160), '\x160'), NCHAR(8201), '\x8201'), NCHAR(8203), '\x8203'), NCHAR(8204), '\x8204'), NCHAR(8205), '\x8205'), NCHAR(8206), '\x8206'), NCHAR(8207), '\x8207'), NCHAR(8239), '\x8239'), NCHAR(12288), '\x12288'), NCHAR(65279), '\x65279') as "{COLUMN_NAME}", COUNT(*) as record_ct FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE("{COLUMN_NAME}", NCHAR(160), ''), NCHAR(8201), ''), NCHAR(8203), ''), NCHAR(8204), ''), NCHAR(8205), ''), NCHAR(8206), ''), NCHAR(8207), ''), NCHAR(8239), ''), NCHAR(12288), ''), NCHAR(65279), '') <> "{COLUMN_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10061' + test_id: 1031 + test_type: Non_Printing_Chars + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE("{COLUMN_NAME}", + CHR(160), '\x160'), + CHR(8201), '\x8201'), + CHR(8203), '\x8203'), + CHR(8204), '\x8204'), + CHR(8205), '\x8205'), + CHR(8206), '\x8206'), + CHR(8207), '\x8207'), + CHR(8239), '\x8239'), + CHR(12288), '\x12288'), + CHR(65279), '\x65279') as "{COLUMN_NAME}", + COUNT(*) as record_ct FROM "{TABLE_NAME}" + WHERE REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE("{COLUMN_NAME}", + CHR(160), ''), + CHR(8201), ''), + CHR(8203), ''), + CHR(8204), ''), + CHR(8205), ''), + CHR(8206), ''), + CHR(8207), ''), + CHR(8239), ''), + CHR(12288), ''), + CHR(65279), '') <> "{COLUMN_NAME}" + GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT} + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Standard_Blanks.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Standard_Blanks.yaml index 4e1c104b..e6bfb600 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Standard_Blanks.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Non_Standard_Blanks.yaml @@ -107,3 +107,11 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE CASE WHEN "{COLUMN_NAME}" IN ('.', '?', ' ') THEN 1 WHEN LOWER("{COLUMN_NAME}") LIKE_REGEXPR '(-{2,}|0{2,}|9{2,}|x{2,}|z{2,})' THEN 1 WHEN LOWER("{COLUMN_NAME}") IN ('blank','error','missing','tbd', 'n/a','#na','none','null','unknown') THEN 1 WHEN LOWER("{COLUMN_NAME}") IN ('(blank)','(error)','(missing)','(tbd)', '(n/a)','(#na)','(none)','(null)','(unknown)') THEN 1 WHEN LOWER("{COLUMN_NAME}") IN ('[blank]','[error]','[missing]','[tbd]', '[n/a]','[#na]','[none]','[null]','[unknown]') THEN 1 WHEN "{COLUMN_NAME}" IS NULL THEN 1 ELSE 0 END = 1 GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10062' + test_id: 1002 + test_type: Non_Standard_Blanks + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE CASE WHEN "{COLUMN_NAME}" IN ('.', '?', ' ') THEN 1 WHEN REGEXP_LIKE(LOWER(CAST("{COLUMN_NAME}" AS VARCHAR)), '-{2,}') OR REGEXP_LIKE(LOWER(CAST("{COLUMN_NAME}" AS VARCHAR)), '0{2,}') OR REGEXP_LIKE(LOWER(CAST("{COLUMN_NAME}" AS VARCHAR)), '9{2,}') OR REGEXP_LIKE(LOWER(CAST("{COLUMN_NAME}" AS VARCHAR)), 'x{2,}') OR REGEXP_LIKE(LOWER(CAST("{COLUMN_NAME}" AS VARCHAR)), 'z{2,}') THEN 1 WHEN LOWER("{COLUMN_NAME}") IN ('blank','error','missing','tbd', 'n/a','#na','none','null','unknown') THEN 1 WHEN LOWER("{COLUMN_NAME}") IN ('(blank)','(error)','(missing)','(tbd)', '(n/a)','(#na)','(none)','(null)','(unknown)') THEN 1 WHEN LOWER("{COLUMN_NAME}") IN ('[blank]','[error]','[missing]','[tbd]', '[n/a]','[#na]','[none]','[null]','[unknown]') THEN 1 WHEN "{COLUMN_NAME}" = '' THEN 1 WHEN "{COLUMN_NAME}" IS NULL THEN 1 ELSE 0 END = 1 GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Potential_Duplicates.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Potential_Duplicates.yaml index 46383270..613f5571 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Potential_Duplicates.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Potential_Duplicates.yaml @@ -96,3 +96,11 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" HAVING COUNT(*) > 1 ORDER BY COUNT(*) DESC LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10063' + test_id: 1016 + test_type: Potential_Duplicates + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" HAVING COUNT(*)> 1 ORDER BY COUNT(*) DESC LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Potential_PII.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Potential_PII.yaml index c33bfae9..492062d9 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Potential_PII.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Potential_PII.yaml @@ -94,3 +94,11 @@ profile_anomaly_types: lookup_query: |- SELECT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10064' + test_id: 1100 + test_type: Potential_PII + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Quoted_Values.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Quoted_Values.yaml index b7ac31bc..a315b5ed 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Quoted_Values.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Quoted_Values.yaml @@ -95,3 +95,11 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE (CASE WHEN "{COLUMN_NAME}" LIKE '"%%"' OR "{COLUMN_NAME}" LIKE '''%%''' THEN 1 ELSE 0 END) = 1 GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10065' + test_id: 1010 + test_type: Quoted_Values + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE (CASE WHEN "{COLUMN_NAME}" LIKE '"%"' OR "{COLUMN_NAME}" LIKE '''%''' THEN 1 ELSE 0 END) = 1 GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Recency_One_Year.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Recency_One_Year.yaml index d24286ca..40e4c02f 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Recency_One_Year.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Recency_One_Year.yaml @@ -90,3 +90,10 @@ profile_anomaly_types: lookup_query: |- created_in_ui error_type: Profile Anomaly + - id: '10066' + test_id: 1019 + test_type: Recency_One_Year + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: created_in_ui + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Recency_Six_Months.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Recency_Six_Months.yaml index a94f7474..0eafe386 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Recency_Six_Months.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Recency_Six_Months.yaml @@ -90,3 +90,10 @@ profile_anomaly_types: lookup_query: |- created_in_ui error_type: Profile Anomaly + - id: '10067' + test_id: 1020 + test_type: Recency_Six_Months + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: created_in_ui + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Small_Divergent_Value_Ct.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Small_Divergent_Value_Ct.yaml index 25c6065a..b8e5ec4d 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Small_Divergent_Value_Ct.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Small_Divergent_Value_Ct.yaml @@ -87,3 +87,11 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY COUNT(*) DESC LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10068' + test_id: 1014 + test_type: Small Divergent Value Ct + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY COUNT(*) DESC LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Small_Missing_Value_Ct.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Small_Missing_Value_Ct.yaml index b8093ab0..3e6fb266 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Small_Missing_Value_Ct.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Small_Missing_Value_Ct.yaml @@ -90,3 +90,11 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE (CASE WHEN "{COLUMN_NAME}" IN ('.', '?', ' ') THEN 1 WHEN LOWER("{COLUMN_NAME}") LIKE_REGEXPR '(-{2,}|0{2,}|9{2,}|x{2,}|z{2,})' THEN 1 WHEN LOWER("{COLUMN_NAME}") IN ('blank','error','missing','tbd', 'n/a','#na','none','null','unknown') THEN 1 WHEN LOWER("{COLUMN_NAME}") IN ('(blank)','(error)','(missing)','(tbd)', '(n/a)','(#na)','(none)','(null)','(unknown)') THEN 1 WHEN LOWER("{COLUMN_NAME}") IN ('[blank]','[error]','[missing]','[tbd]', '[n/a]','[#na]','[none]','[null]','[unknown]') THEN 1 WHEN "{COLUMN_NAME}" IS NULL THEN 1 ELSE 0 END) = 1 GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10069' + test_id: 1013 + test_type: Small Missing Value Ct + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE (CASE WHEN "{COLUMN_NAME}" IN ('.', '?', ' ') THEN 1 WHEN REGEXP_LIKE(LOWER(CAST("{COLUMN_NAME}" AS VARCHAR)), '-{2,}') OR REGEXP_LIKE(LOWER(CAST("{COLUMN_NAME}" AS VARCHAR)), '0{2,}') OR REGEXP_LIKE(LOWER(CAST("{COLUMN_NAME}" AS VARCHAR)), '9{2,}') OR REGEXP_LIKE(LOWER(CAST("{COLUMN_NAME}" AS VARCHAR)), 'x{2,}') OR REGEXP_LIKE(LOWER(CAST("{COLUMN_NAME}" AS VARCHAR)), 'z{2,}') THEN 1 WHEN LOWER("{COLUMN_NAME}") IN ('blank','error','missing','tbd', 'n/a','#na','none','null','unknown') THEN 1 WHEN LOWER("{COLUMN_NAME}") IN ('(blank)','(error)','(missing)','(tbd)', '(n/a)','(#na)','(none)','(null)','(unknown)') THEN 1 WHEN LOWER("{COLUMN_NAME}") IN ('[blank]','[error]','[missing]','[tbd]', '[n/a]','[#na]','[none]','[null]','[unknown]') THEN 1 WHEN "{COLUMN_NAME}" = '' THEN 1 WHEN "{COLUMN_NAME}" IS NULL THEN 1 ELSE 0 END) = 1 GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Small_Numeric_Value_Ct.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Small_Numeric_Value_Ct.yaml index 0b868784..5249c9d6 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Small_Numeric_Value_Ct.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Small_Numeric_Value_Ct.yaml @@ -109,3 +109,11 @@ profile_anomaly_types: lookup_query: |- SELECT A.* FROM (SELECT DISTINCT 'Numeric' as data_type, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE <%IS_NUM;"{COLUMN_NAME}"%> = 1 GROUP BY "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_2}) A UNION ALL SELECT B.* FROM (SELECT DISTINCT 'Non-Numeric' as data_type, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE <%IS_NUM;"{COLUMN_NAME}"%> != 1 GROUP BY "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_2}) B ORDER BY data_type, count DESC error_type: Profile Anomaly + - id: '10070' + test_id: 1023 + test_type: Small_Numeric_Value_Ct + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT A.* FROM ( SELECT DISTINCT 'Numeric' as data_type, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE <%IS_NUM;"{COLUMN_NAME}"%> = 1 GROUP BY "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_2}) AS A UNION ALL SELECT B.* FROM ( SELECT DISTINCT 'Non-Numeric' as data_type, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE <%IS_NUM;"{COLUMN_NAME}"%> != 1 GROUP BY "{COLUMN_NAME}" ORDER BY count DESC LIMIT {LIMIT_2}) AS B ORDER BY data_type, count DESC; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Standardized_Value_Matches.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Standardized_Value_Matches.yaml index 870862a4..4e4f43ad 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Standardized_Value_Matches.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Standardized_Value_Matches.yaml @@ -104,3 +104,11 @@ profile_anomaly_types: lookup_query: |- WITH CTE AS ( SELECT DISTINCT UPPER(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE("{COLUMN_NAME}", ' ', ''), '''', ''), ',', ''), '.', ''), '-', '')) as possible_standard_value, COUNT(DISTINCT "{COLUMN_NAME}") AS cnt FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY UPPER(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE("{COLUMN_NAME}", ' ', ''), '''', ''), ',', ''), '.', ''), '-', '')) HAVING COUNT(DISTINCT "{COLUMN_NAME}") > 1 ) SELECT a."{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" a, cte b WHERE UPPER(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(a."{COLUMN_NAME}", ' ', ''), '''', ''), ',', ''), '.', ''), '-', '')) = b.possible_standard_value GROUP BY a."{COLUMN_NAME}" ORDER BY UPPER(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(a."{COLUMN_NAME}", ' ', ''), '''', ''), ',', ''), '.', ''), '-', '')) ASC, count DESC LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10071' + test_id: 1017 + test_type: Standardized_Value_Matches + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + WITH CTE AS ( SELECT DISTINCT UPPER(REGEXP_REPLACE("{COLUMN_NAME}", '[ '',.\-]', '', 'g')) as possible_standard_value, COUNT(DISTINCT "{COLUMN_NAME}") FROM "{TABLE_NAME}" GROUP BY UPPER(REGEXP_REPLACE("{COLUMN_NAME}", '[ '',.\-]', '', 'g')) HAVING COUNT(DISTINCT "{COLUMN_NAME}") > 1 ) SELECT a."{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" a, cte b WHERE UPPER(REGEXP_REPLACE(a."{COLUMN_NAME}", '[ '',.\-]', '', 'g')) = b.possible_standard_value GROUP BY a."{COLUMN_NAME}" ORDER BY UPPER(REGEXP_REPLACE(a."{COLUMN_NAME}", '[ '',.\-]', '', 'g')) ASC, count DESC LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Suggested_Type.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Suggested_Type.yaml index b623888b..812c6b95 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Suggested_Type.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Suggested_Type.yaml @@ -95,3 +95,11 @@ profile_anomaly_types: lookup_query: |- SELECT "{COLUMN_NAME}", COUNT(*) AS record_ct FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY record_ct DESC LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10072' + test_id: 1001 + test_type: Suggested_Type + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT "{COLUMN_NAME}", COUNT(*) AS record_ct FROM "{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY record_ct DESC LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Table_Pattern_Mismatch.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Table_Pattern_Mismatch.yaml index d72d9875..7be19eb1 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Table_Pattern_Mismatch.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Table_Pattern_Mismatch.yaml @@ -105,3 +105,10 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT COLUMN_NAME, TABLE_NAME FROM SYS.TABLE_COLUMNS WHERE SCHEMA_NAME = '{TARGET_SCHEMA}' AND COLUMN_NAME = '{COLUMN_NAME}' ORDER BY TABLE_NAME LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10073' + test_id: 1008 + test_type: Table_Pattern_Mismatch + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: null + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Unexpected_Emails.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Unexpected_Emails.yaml index 9c9dd4f8..f939cf78 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Unexpected_Emails.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Unexpected_Emails.yaml @@ -95,3 +95,11 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10074' + test_id: 1022 + test_type: Unexpected Emails + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Unexpected_US_States.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Unexpected_US_States.yaml index b86117ab..669c5259 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Unexpected_US_States.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Unexpected_US_States.yaml @@ -97,3 +97,11 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10075' + test_id: 1021 + test_type: Unexpected US States + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Unlikely_Date_Values.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Unlikely_Date_Values.yaml index c5f9c540..f75eb93f 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Unlikely_Date_Values.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Unlikely_Date_Values.yaml @@ -99,3 +99,11 @@ profile_anomaly_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", TO_DATE('{PROFILE_RUN_DATE}', 'YYYY-MM-DD') AS profile_run_date, COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" a WHERE ("{COLUMN_NAME}" < TO_DATE('1900-01-01', 'YYYY-MM-DD')) OR ("{COLUMN_NAME}" > ADD_MONTHS(TO_DATE('{PROFILE_RUN_DATE}', 'YYYY-MM-DD'), 360)) GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10076' + test_id: 1018 + test_type: Unlikely_Date_Values + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", CAST('{PROFILE_RUN_DATE}' AS DATE) AS profile_run_date, COUNT(*) AS count FROM "{TABLE_NAME}" a WHERE ("{COLUMN_NAME}" < CAST('1900-01-01' AS DATE)) OR ("{COLUMN_NAME}" > CAST('{PROFILE_RUN_DATE}' AS DATE) + INTERVAL '30 year' ) GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Variant_Coded_Values.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Variant_Coded_Values.yaml index 72265501..e252f4c7 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Variant_Coded_Values.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Variant_Coded_Values.yaml @@ -98,3 +98,11 @@ profile_anomaly_types: lookup_query: |- WITH val_list(token, remaining) AS ( SELECT CASE WHEN LOCATE(SUBSTR('{DETAIL_EXPRESSION}', LOCATE('{DETAIL_EXPRESSION}', ':') + 2), '|') > 0 THEN TRIM(SUBSTR(SUBSTR('{DETAIL_EXPRESSION}', LOCATE('{DETAIL_EXPRESSION}', ':') + 2), 1, LOCATE(SUBSTR('{DETAIL_EXPRESSION}', LOCATE('{DETAIL_EXPRESSION}', ':') + 2), '|') - 1)) ELSE TRIM(SUBSTR('{DETAIL_EXPRESSION}', LOCATE('{DETAIL_EXPRESSION}', ':') + 2)) END AS token, CASE WHEN LOCATE(SUBSTR('{DETAIL_EXPRESSION}', LOCATE('{DETAIL_EXPRESSION}', ':') + 2), '|') > 0 THEN SUBSTR(SUBSTR('{DETAIL_EXPRESSION}', LOCATE('{DETAIL_EXPRESSION}', ':') + 2), LOCATE(SUBSTR('{DETAIL_EXPRESSION}', LOCATE('{DETAIL_EXPRESSION}', ':') + 2), '|') + 1) ELSE '' END AS remaining FROM DUMMY UNION ALL SELECT CASE WHEN LOCATE(remaining, '|') > 0 THEN TRIM(SUBSTR(remaining, 1, LOCATE(remaining, '|') - 1)) ELSE TRIM(remaining) END AS token, CASE WHEN LOCATE(remaining, '|') > 0 THEN SUBSTR(remaining, LOCATE(remaining, '|') + 1) ELSE '' END AS remaining FROM val_list WHERE LENGTH(remaining) > 0 ) SELECT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE LOWER("{COLUMN_NAME}") IN (SELECT token FROM val_list) GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Profile Anomaly + - id: '10077' + test_id: 1027 + test_type: Variant_Coded_Values + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE LOWER("{COLUMN_NAME}") IN (SELECT TRIM(val) FROM UNNEST(STRING_TO_ARRAY(SUBSTRING('{DETAIL_EXPRESSION}', STRPOS('{DETAIL_EXPRESSION}', ':') + 2), '|')) AS t(val)) GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Profile Anomaly diff --git a/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance.yaml b/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance.yaml index 89882477..cb98ab7e 100644 --- a/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance.yaml @@ -266,6 +266,31 @@ test_types: ORDER BY {GROUPBY_NAMES} LIMIT {LIMIT} error_type: Test Results + - id: '10001' + test_id: 1500 + test_type: Aggregate_Balance + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT * + FROM ( SELECT {GROUPBY_NAMES}, SUM(TOTAL) AS total, SUM(MATCH_TOTAL) AS MATCH_TOTAL + FROM + ( SELECT {GROUPBY_NAMES}, {COLUMN_NAME_NO_QUOTES} AS total, NULL AS match_total + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + GROUP BY {GROUPBY_NAMES} + {HAVING_CONDITION} + UNION ALL + SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total + FROM "{MATCH_TABLE_NAME}" + WHERE {MATCH_SUBSET_CONDITION} + GROUP BY {MATCH_GROUPBY_NAMES} + {MATCH_HAVING_CONDITION} ) a + GROUP BY {GROUPBY_NAMES} ) s + WHERE total <> match_total OR (total IS NOT NULL AND match_total IS NULL) OR (total IS NULL AND match_total IS NOT NULL) + ORDER BY {GROUPBY_NAMES} + LIMIT {LIMIT}; + error_type: Test Results test_templates: - id: '2506' test_type: Aggregate_Balance @@ -698,3 +723,52 @@ test_types: WHERE total <> match_total OR (total IS NOT NULL AND match_total IS NULL) OR (total IS NULL AND match_total IS NOT NULL) + - id: '10001' + test_type: Aggregate_Balance + sql_flavor: salesforce_data360 + template: |- + SELECT '{TEST_TYPE}' as test_type, + '{TEST_DEFINITION_ID}' as test_definition_id, + '{TEST_SUITE_ID}' as test_suite_id, + '{TEST_RUN_ID}' as test_run_id, + '{RUN_DATE}' as test_time, + '{SCHEMA_NAME}' as schema_name, + '{TABLE_NAME}' as table_name, + '{COLUMN_NAME_NO_QUOTES}' as column_names, + '{SKIP_ERRORS}' as threshold_value, + {SKIP_ERRORS} as skip_errors, + '{INPUT_PARAMETERS}' as input_parameters, + NULL as result_signal, + CASE WHEN COUNT (*) > {SKIP_ERRORS} THEN 0 ELSE 1 END as result_code, + CASE + WHEN COUNT(*) > 0 THEN + CONCAT( + CONCAT( CAST(COUNT(*) AS {VARCHAR_TYPE}), ' error(s) identified, ' ), + CONCAT( + CASE + WHEN COUNT(*) > {SKIP_ERRORS} THEN 'exceeding limit of ' + ELSE 'within limit of ' + END, + '{SKIP_ERRORS}.' + ) + ) + ELSE 'No errors found.' + END AS result_message, + COUNT(*) as result_measure + FROM ( SELECT {GROUPBY_NAMES}, SUM(TOTAL) as total, SUM(MATCH_TOTAL) as MATCH_TOTAL + FROM + ( SELECT {GROUPBY_NAMES}, {COLUMN_NAME_NO_QUOTES} as total, NULL as match_total + FROM {QUOTE}{TABLE_NAME}{QUOTE} + WHERE {SUBSET_CONDITION} + GROUP BY {GROUPBY_NAMES} + {HAVING_CONDITION} + UNION ALL + SELECT {MATCH_GROUPBY_NAMES}, NULL as total, {MATCH_COLUMN_NAMES} as match_total + FROM {QUOTE}{MATCH_TABLE_NAME}{QUOTE} + WHERE {MATCH_SUBSET_CONDITION} + GROUP BY {MATCH_GROUPBY_NAMES} + {MATCH_HAVING_CONDITION} ) a + GROUP BY {GROUPBY_NAMES} ) s + WHERE total <> match_total + OR (total IS NOT NULL AND match_total IS NULL) + OR (total IS NULL AND match_total IS NOT NULL); diff --git a/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance_Percent.yaml b/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance_Percent.yaml index b15b0114..3b0d81d7 100644 --- a/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance_Percent.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance_Percent.yaml @@ -284,6 +284,33 @@ test_types: ORDER BY {GROUPBY_NAMES} LIMIT {LIMIT} error_type: Test Results + - id: '10002' + test_id: 1504 + test_type: Aggregate_Balance_Percent + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT * + FROM ( SELECT {GROUPBY_NAMES}, SUM(TOTAL) AS total, SUM(MATCH_TOTAL) AS MATCH_TOTAL + FROM + ( SELECT {GROUPBY_NAMES}, {COLUMN_NAME_NO_QUOTES} AS total, NULL AS match_total + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + GROUP BY {GROUPBY_NAMES} + {HAVING_CONDITION} + UNION ALL + SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total + FROM "{MATCH_TABLE_NAME}" + WHERE {MATCH_SUBSET_CONDITION} + GROUP BY {MATCH_GROUPBY_NAMES} + {MATCH_HAVING_CONDITION} ) a + GROUP BY {GROUPBY_NAMES} ) s + WHERE (total IS NOT NULL AND match_total IS NULL) + OR (total IS NULL AND match_total IS NOT NULL) + OR (total NOT BETWEEN match_total * (1 + {LOWER_TOLERANCE}/100.0) AND match_total * (1 + {UPPER_TOLERANCE}/100.0)) + ORDER BY {GROUPBY_NAMES} + LIMIT {LIMIT}; + error_type: Test Results test_templates: - id: '2509' test_type: Aggregate_Balance_Percent @@ -716,3 +743,52 @@ test_types: WHERE (total IS NOT NULL AND match_total IS NULL) OR (total IS NULL AND match_total IS NOT NULL) OR (total NOT BETWEEN match_total * (1 + {LOWER_TOLERANCE}/100.0) AND match_total * (1 + {UPPER_TOLERANCE}/100.0)) + - id: '10002' + test_type: Aggregate_Balance_Percent + sql_flavor: salesforce_data360 + template: |- + SELECT '{TEST_TYPE}' as test_type, + '{TEST_DEFINITION_ID}' as test_definition_id, + '{TEST_SUITE_ID}' as test_suite_id, + '{TEST_RUN_ID}' as test_run_id, + '{RUN_DATE}' as test_time, + '{SCHEMA_NAME}' as schema_name, + '{TABLE_NAME}' as table_name, + '{COLUMN_NAME_NO_QUOTES}' as column_names, + '{SKIP_ERRORS}' as threshold_value, + {SKIP_ERRORS} as skip_errors, + '{INPUT_PARAMETERS}' as input_parameters, + NULL as result_signal, + CASE WHEN COUNT (*) > {SKIP_ERRORS} THEN 0 ELSE 1 END as result_code, + CASE + WHEN COUNT(*) > 0 THEN + CONCAT( + CONCAT( CAST(COUNT(*) AS {VARCHAR_TYPE}), ' error(s) identified, ' ), + CONCAT( + CASE + WHEN COUNT(*) > {SKIP_ERRORS} THEN 'exceeding limit of ' + ELSE 'within limit of ' + END, + '{SKIP_ERRORS}.' + ) + ) + ELSE 'No errors found.' + END AS result_message, + COUNT(*) as result_measure + FROM ( SELECT {GROUPBY_NAMES}, SUM(TOTAL) as total, SUM(MATCH_TOTAL) as MATCH_TOTAL + FROM + ( SELECT {GROUPBY_NAMES}, {COLUMN_NAME_NO_QUOTES} as total, NULL as match_total + FROM {QUOTE}{TABLE_NAME}{QUOTE} + WHERE {SUBSET_CONDITION} + GROUP BY {GROUPBY_NAMES} + {HAVING_CONDITION} + UNION ALL + SELECT {MATCH_GROUPBY_NAMES}, NULL as total, {MATCH_COLUMN_NAMES} as match_total + FROM {QUOTE}{MATCH_TABLE_NAME}{QUOTE} + WHERE {MATCH_SUBSET_CONDITION} + GROUP BY {MATCH_GROUPBY_NAMES} + {MATCH_HAVING_CONDITION} ) a + GROUP BY {GROUPBY_NAMES} ) s + WHERE (total IS NOT NULL AND match_total IS NULL) + OR (total IS NULL AND match_total IS NOT NULL) + OR (total NOT BETWEEN match_total * (1 + {LOWER_TOLERANCE}/100.0) AND match_total * (1 + {UPPER_TOLERANCE}/100.0)); diff --git a/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance_Range.yaml b/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance_Range.yaml index 1fe4cdc4..2fc50146 100644 --- a/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance_Range.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance_Range.yaml @@ -284,6 +284,33 @@ test_types: ORDER BY {GROUPBY_NAMES} LIMIT {LIMIT} error_type: Test Results + - id: '10003' + test_id: 1505 + test_type: Aggregate_Balance_Range + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT * + FROM ( SELECT {GROUPBY_NAMES}, SUM(TOTAL) AS total, SUM(MATCH_TOTAL) AS MATCH_TOTAL + FROM + ( SELECT {GROUPBY_NAMES}, {COLUMN_NAME_NO_QUOTES} AS total, NULL AS match_total + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + GROUP BY {GROUPBY_NAMES} + {HAVING_CONDITION} + UNION ALL + SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total + FROM "{MATCH_TABLE_NAME}" + WHERE {MATCH_SUBSET_CONDITION} + GROUP BY {MATCH_GROUPBY_NAMES} + {MATCH_HAVING_CONDITION} ) a + GROUP BY {GROUPBY_NAMES} ) s + WHERE (total IS NOT NULL AND match_total IS NULL) + OR (total IS NULL AND match_total IS NOT NULL) + OR (total NOT BETWEEN match_total + {LOWER_TOLERANCE} AND match_total + {UPPER_TOLERANCE}) + ORDER BY {GROUPBY_NAMES} + LIMIT {LIMIT}; + error_type: Test Results test_templates: - id: '2510' test_type: Aggregate_Balance_Range @@ -716,3 +743,52 @@ test_types: WHERE (total IS NOT NULL AND match_total IS NULL) OR (total IS NULL AND match_total IS NOT NULL) OR (total NOT BETWEEN match_total + {LOWER_TOLERANCE} AND match_total + {UPPER_TOLERANCE}) + - id: '10003' + test_type: Aggregate_Balance_Range + sql_flavor: salesforce_data360 + template: |- + SELECT '{TEST_TYPE}' as test_type, + '{TEST_DEFINITION_ID}' as test_definition_id, + '{TEST_SUITE_ID}' as test_suite_id, + '{TEST_RUN_ID}' as test_run_id, + '{RUN_DATE}' as test_time, + '{SCHEMA_NAME}' as schema_name, + '{TABLE_NAME}' as table_name, + '{COLUMN_NAME_NO_QUOTES}' as column_names, + '{SKIP_ERRORS}' as threshold_value, + {SKIP_ERRORS} as skip_errors, + '{INPUT_PARAMETERS}' as input_parameters, + NULL as result_signal, + CASE WHEN COUNT (*) > {SKIP_ERRORS} THEN 0 ELSE 1 END as result_code, + CASE + WHEN COUNT(*) > 0 THEN + CONCAT( + CONCAT( CAST(COUNT(*) AS {VARCHAR_TYPE}), ' error(s) identified, ' ), + CONCAT( + CASE + WHEN COUNT(*) > {SKIP_ERRORS} THEN 'exceeding limit of ' + ELSE 'within limit of ' + END, + '{SKIP_ERRORS}.' + ) + ) + ELSE 'No errors found.' + END AS result_message, + COUNT(*) as result_measure + FROM ( SELECT {GROUPBY_NAMES}, SUM(TOTAL) as total, SUM(MATCH_TOTAL) as MATCH_TOTAL + FROM + ( SELECT {GROUPBY_NAMES}, {COLUMN_NAME_NO_QUOTES} as total, NULL as match_total + FROM {QUOTE}{TABLE_NAME}{QUOTE} + WHERE {SUBSET_CONDITION} + GROUP BY {GROUPBY_NAMES} + {HAVING_CONDITION} + UNION ALL + SELECT {MATCH_GROUPBY_NAMES}, NULL as total, {MATCH_COLUMN_NAMES} as match_total + FROM {QUOTE}{MATCH_TABLE_NAME}{QUOTE} + WHERE {MATCH_SUBSET_CONDITION} + GROUP BY {MATCH_GROUPBY_NAMES} + {MATCH_HAVING_CONDITION} ) a + GROUP BY {GROUPBY_NAMES} ) s + WHERE (total IS NOT NULL AND match_total IS NULL) + OR (total IS NULL AND match_total IS NOT NULL) + OR (total NOT BETWEEN match_total + {LOWER_TOLERANCE} AND match_total + {UPPER_TOLERANCE}); diff --git a/testgen/template/dbsetup_test_types/test_types_Aggregate_Minimum.yaml b/testgen/template/dbsetup_test_types/test_types_Aggregate_Minimum.yaml index 8607dec0..cd35e549 100644 --- a/testgen/template/dbsetup_test_types/test_types_Aggregate_Minimum.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Aggregate_Minimum.yaml @@ -266,6 +266,31 @@ test_types: ORDER BY {GROUPBY_NAMES} LIMIT {LIMIT} error_type: Test Results + - id: '10004' + test_id: 1501 + test_type: Aggregate_Minimum + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT * + FROM ( SELECT {GROUPBY_NAMES}, SUM(TOTAL) as total, SUM(MATCH_TOTAL) as MATCH_TOTAL + FROM + ( SELECT {GROUPBY_NAMES}, {COLUMN_NAME_NO_QUOTES} as total, NULL as match_total + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + GROUP BY {GROUPBY_NAMES} + {HAVING_CONDITION} + UNION ALL + SELECT {MATCH_GROUPBY_NAMES}, NULL as total, {MATCH_COLUMN_NAMES} as match_total + FROM "{MATCH_TABLE_NAME}" + WHERE {MATCH_SUBSET_CONDITION} + GROUP BY {MATCH_GROUPBY_NAMES} + {MATCH_HAVING_CONDITION} ) a + GROUP BY {GROUPBY_NAMES} ) s + WHERE total < match_total OR (total IS NULL AND match_total IS NOT NULL) + ORDER BY {GROUPBY_NAMES} + LIMIT {LIMIT}; + error_type: Test Results test_templates: - id: '2502' test_type: Aggregate_Minimum @@ -698,3 +723,52 @@ test_types: WHERE total < match_total -- OR (total IS NOT NULL AND match_total IS NULL) -- New categories OR (total IS NULL AND match_total IS NOT NULL) + - id: '10004' + test_type: Aggregate_Minimum + sql_flavor: salesforce_data360 + template: |- + SELECT '{TEST_TYPE}' as test_type, + '{TEST_DEFINITION_ID}' as test_definition_id, + '{TEST_SUITE_ID}' as test_suite_id, + '{TEST_RUN_ID}' as test_run_id, + '{RUN_DATE}' as test_time, + '{SCHEMA_NAME}' as schema_name, + '{TABLE_NAME}' as table_name, + '{COLUMN_NAME_NO_QUOTES}' as column_names, + '{SKIP_ERRORS}' as threshold_value, + {SKIP_ERRORS} as skip_errors, + '{INPUT_PARAMETERS}' as input_parameters, + NULL as result_signal, + CASE WHEN COUNT (*) > {SKIP_ERRORS} THEN 0 ELSE 1 END as result_code, + CASE + WHEN COUNT(*) > 0 THEN + CONCAT( + CONCAT( CAST(COUNT(*) AS {VARCHAR_TYPE}), ' error(s) identified, ' ), + CONCAT( + CASE + WHEN COUNT(*) > {SKIP_ERRORS} THEN 'exceeding limit of ' + ELSE 'within limit of ' + END, + '{SKIP_ERRORS}.' + ) + ) + ELSE 'No errors found.' + END AS result_message, + COUNT(*) as result_measure + FROM ( SELECT {GROUPBY_NAMES}, SUM(TOTAL) as total, SUM(MATCH_TOTAL) as MATCH_TOTAL + FROM + ( SELECT {GROUPBY_NAMES}, {COLUMN_NAME_NO_QUOTES} as total, NULL as match_total + FROM {QUOTE}{TABLE_NAME}{QUOTE} + WHERE {SUBSET_CONDITION} + GROUP BY {GROUPBY_NAMES} + {HAVING_CONDITION} + UNION ALL + SELECT {MATCH_GROUPBY_NAMES}, NULL as total, {MATCH_COLUMN_NAMES} as match_total + FROM {QUOTE}{MATCH_TABLE_NAME}{QUOTE} + WHERE {MATCH_SUBSET_CONDITION} + GROUP BY {MATCH_GROUPBY_NAMES} + {MATCH_HAVING_CONDITION} ) a + GROUP BY {GROUPBY_NAMES} ) s + WHERE total < match_total + -- OR (total IS NOT NULL AND match_total IS NULL) -- New categories + OR (total IS NULL AND match_total IS NOT NULL); -- Dropped categories diff --git a/testgen/template/dbsetup_test_types/test_types_Alpha_Trunc.yaml b/testgen/template/dbsetup_test_types/test_types_Alpha_Trunc.yaml index 41ab1ab7..88577f60 100644 --- a/testgen/template/dbsetup_test_types/test_types_Alpha_Trunc.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Alpha_Trunc.yaml @@ -117,6 +117,14 @@ test_types: test_operator: < test_condition: |- {THRESHOLD_VALUE} + - id: '10001' + test_type: Alpha_Trunc + sql_flavor: salesforce_data360 + measure: |- + MAX(LENGTH({COLUMN_NAME})) + test_operator: '<' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1364' test_id: '1004' @@ -197,4 +205,12 @@ test_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", LENGTH("{COLUMN_NAME}") as current_max_length, {THRESHOLD_VALUE} as previous_max_length FROM "{TARGET_SCHEMA}"."{TABLE_NAME}", (SELECT MAX(LENGTH("{COLUMN_NAME}")) as max_length FROM "{TARGET_SCHEMA}"."{TABLE_NAME}") a WHERE LENGTH("{COLUMN_NAME}") = a.max_length AND a.max_length < {THRESHOLD_VALUE} LIMIT {LIMIT} error_type: Test Results + - id: '10005' + test_id: 1004 + test_type: Alpha_Trunc + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", LENGTH("{COLUMN_NAME}") as current_max_length, {THRESHOLD_VALUE} as previous_max_length FROM "{TABLE_NAME}", (SELECT MAX(LENGTH("{COLUMN_NAME}")) as max_length FROM "{TABLE_NAME}") a WHERE LENGTH("{COLUMN_NAME}") = a.max_length AND a.max_length < {THRESHOLD_VALUE} LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Avg_Shift.yaml b/testgen/template/dbsetup_test_types/test_types_Avg_Shift.yaml index 49a3c5b9..08f801a7 100644 --- a/testgen/template/dbsetup_test_types/test_types_Avg_Shift.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Avg_Shift.yaml @@ -118,6 +118,14 @@ test_types: test_operator: '>=' test_condition: |- {THRESHOLD_VALUE} + - id: '10002' + test_type: Avg_Shift + sql_flavor: salesforce_data360 + measure: |- + ABS( (AVG(CAST({COLUMN_NAME} AS FLOAT)) - {BASELINE_AVG}) / SQRT(((CAST(COUNT({COLUMN_NAME}) AS FLOAT)-1)*POWER(STDDEV(CAST({COLUMN_NAME} AS FLOAT)),2) + (CAST({BASELINE_VALUE_CT} AS FLOAT)-1) * POWER(CAST({BASELINE_SD} AS FLOAT),2)) /NULLIF(CAST(COUNT({COLUMN_NAME}) AS FLOAT) + CAST({BASELINE_VALUE_CT} AS FLOAT), 0) )) + test_operator: '>=' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1365' test_id: '1005' @@ -192,4 +200,12 @@ test_types: lookup_query: |- SELECT AVG(CAST("{COLUMN_NAME}" AS DECIMAL)) AS current_average FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" error_type: Test Results + - id: '10006' + test_id: 1005 + test_type: Avg_Shift + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT AVG(CAST("{COLUMN_NAME}" AS FLOAT)) AS current_average FROM "{TABLE_NAME}"; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_CUSTOM.yaml b/testgen/template/dbsetup_test_types/test_types_CUSTOM.yaml index 3257b114..6005806c 100644 --- a/testgen/template/dbsetup_test_types/test_types_CUSTOM.yaml +++ b/testgen/template/dbsetup_test_types/test_types_CUSTOM.yaml @@ -382,3 +382,42 @@ test_types: FROM ( {CUSTOM_QUERY} ) TEST + - id: '10005' + test_type: CUSTOM + sql_flavor: salesforce_data360 + template: |- + SELECT '{TEST_TYPE}' as test_type, + '{TEST_DEFINITION_ID}' as test_definition_id, + '{TEST_SUITE_ID}' as test_suite_id, + '{TEST_RUN_ID}' as test_run_id, + '{RUN_DATE}' as test_time, + '{SCHEMA_NAME}' as schema_name, + '{TABLE_NAME}' as table_name, + CASE + WHEN '{COLUMN_NAME_NO_QUOTES}' = '' OR '{COLUMN_NAME_NO_QUOTES}' IS NULL THEN NULL + ELSE '{COLUMN_NAME_NO_QUOTES}' + END as column_names, + '{SKIP_ERRORS}' as threshold_value, + {SKIP_ERRORS} as skip_errors, + /* TODO: 'custom_query= {CUSTOM_QUERY_ESCAPED}' as input_parameters, */ + 'Skip_Errors={SKIP_ERRORS}' as input_parameters, + NULL as result_signal, + CASE WHEN COUNT (*) > {SKIP_ERRORS} THEN 0 ELSE 1 END as result_code, + CASE + WHEN COUNT(*) > 0 THEN + CONCAT( + CONCAT( CAST(COUNT(*) AS {VARCHAR_TYPE}), ' error(s) identified, ' ), + CONCAT( + CASE + WHEN COUNT(*) > {SKIP_ERRORS} THEN 'exceeding limit of ' + ELSE 'within limit of ' + END, + '{SKIP_ERRORS}.' + ) + ) + ELSE 'No errors found.' + END AS result_message, + COUNT(*) as result_measure + FROM ( + {CUSTOM_QUERY} + ) TEST; diff --git a/testgen/template/dbsetup_test_types/test_types_Combo_Match.yaml b/testgen/template/dbsetup_test_types/test_types_Combo_Match.yaml index cdc5bfde..3b027325 100644 --- a/testgen/template/dbsetup_test_types/test_types_Combo_Match.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Combo_Match.yaml @@ -237,6 +237,28 @@ test_types: ORDER BY {COLUMN_NAME_NO_QUOTES} LIMIT {LIMIT} error_type: Test Results + - id: '10007' + test_id: 1502 + test_type: Combo_Match + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT * + FROM ( SELECT {COLUMN_NAME_NO_QUOTES} + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + GROUP BY {COLUMN_NAME_NO_QUOTES} + {HAVING_CONDITION} + EXCEPT + SELECT {MATCH_GROUPBY_NAMES} + FROM "{MATCH_TABLE_NAME}" + WHERE {MATCH_SUBSET_CONDITION} + GROUP BY {MATCH_GROUPBY_NAMES} + {MATCH_HAVING_CONDITION} + ) test + ORDER BY {COLUMN_NAME_NO_QUOTES} + LIMIT {LIMIT}; + error_type: Test Results test_templates: - id: '2501' test_type: Combo_Match @@ -626,3 +648,47 @@ test_types: GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) test + - id: '10006' + test_type: Combo_Match + sql_flavor: salesforce_data360 + template: |- + SELECT '{TEST_TYPE}' as test_type, + '{TEST_DEFINITION_ID}' as test_definition_id, + '{TEST_SUITE_ID}' as test_suite_id, + '{TEST_RUN_ID}' as test_run_id, + '{RUN_DATE}' as test_time, + '{SCHEMA_NAME}' as schema_name, + '{TABLE_NAME}' as table_name, + '{COLUMN_NAME_NO_QUOTES}' as column_names, + '{SKIP_ERRORS}' as threshold_value, + {SKIP_ERRORS} as skip_errors, + '{INPUT_PARAMETERS}' as input_parameters, + NULL as result_signal, + CASE WHEN COUNT (*) > {SKIP_ERRORS} THEN 0 ELSE 1 END as result_code, + CASE + WHEN COUNT(*) > 0 THEN + CONCAT( + CONCAT( CAST(COUNT(*) AS {VARCHAR_TYPE}), ' error(s) identified, ' ), + CONCAT( + CASE + WHEN COUNT(*) > {SKIP_ERRORS} THEN 'exceeding limit of ' + ELSE 'within limit of ' + END, + '{SKIP_ERRORS}.' + ) + ) + ELSE 'No errors found.' + END AS result_message, + COUNT(*) as result_measure + FROM ( SELECT {COLUMN_NAME_NO_QUOTES} + FROM {QUOTE}{TABLE_NAME}{QUOTE} + WHERE {SUBSET_CONDITION} + GROUP BY {COLUMN_NAME_NO_QUOTES} + {HAVING_CONDITION} + EXCEPT + SELECT {MATCH_GROUPBY_NAMES} + FROM {QUOTE}{MATCH_TABLE_NAME}{QUOTE} + WHERE {MATCH_SUBSET_CONDITION} + GROUP BY {MATCH_GROUPBY_NAMES} + {MATCH_HAVING_CONDITION} + ) test; diff --git a/testgen/template/dbsetup_test_types/test_types_Condition_Flag.yaml b/testgen/template/dbsetup_test_types/test_types_Condition_Flag.yaml index 733ef0b5..9c63f169 100644 --- a/testgen/template/dbsetup_test_types/test_types_Condition_Flag.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Condition_Flag.yaml @@ -118,6 +118,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10003' + test_type: Condition_Flag + sql_flavor: salesforce_data360 + measure: |- + SUM(CASE WHEN {CUSTOM_QUERY} THEN 1 ELSE 0 END) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1366' test_id: '1006' @@ -194,4 +202,12 @@ test_types: lookup_query: |- SELECT * FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE {CUSTOM_QUERY} LIMIT {LIMIT} error_type: Test Results + - id: '10008' + test_id: 1006 + test_type: Condition_Flag + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT * FROM "{TABLE_NAME}" WHERE {CUSTOM_QUERY} LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Constant.yaml b/testgen/template/dbsetup_test_types/test_types_Constant.yaml index 2bb8e6df..fff389cd 100644 --- a/testgen/template/dbsetup_test_types/test_types_Constant.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Constant.yaml @@ -117,6 +117,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10004' + test_type: Constant + sql_flavor: salesforce_data360 + measure: |- + SUM(CASE WHEN {COLUMN_NAME} <> {BASELINE_VALUE} THEN 1 ELSE 0 END) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1367' test_id: '1007' @@ -194,4 +202,12 @@ test_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE "{COLUMN_NAME}" <> {BASELINE_VALUE} GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Test Results + - id: '10009' + test_id: 1007 + test_type: Constant + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE "{COLUMN_NAME}" <> {BASELINE_VALUE} GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Daily_Record_Ct.yaml b/testgen/template/dbsetup_test_types/test_types_Daily_Record_Ct.yaml index fb9fe8bb..df372cd6 100644 --- a/testgen/template/dbsetup_test_types/test_types_Daily_Record_Ct.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Daily_Record_Ct.yaml @@ -121,6 +121,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10005' + test_type: Daily_Record_Ct + sql_flavor: salesforce_data360 + measure: |- + DATEDIFF('day', CAST(MIN({COLUMN_NAME}) AS DATE), CAST(MAX({COLUMN_NAME}) AS DATE))+1-COUNT(DISTINCT CAST({COLUMN_NAME} AS DATE)) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1368' test_id: '1009' @@ -263,4 +271,12 @@ test_types: lookup_query: |- WITH Pass0 AS (SELECT 1 C FROM DUMMY UNION ALL SELECT 1 FROM DUMMY), Pass1 AS (SELECT 1 C FROM Pass0 A, Pass0 B), Pass2 AS (SELECT 1 C FROM Pass1 A, Pass1 B), Pass3 AS (SELECT 1 C FROM Pass2 A, Pass2 B), Pass4 AS (SELECT 1 C FROM Pass3 A, Pass3 B), nums AS (SELECT ROW_NUMBER() OVER (ORDER BY C) - 1 AS rn FROM Pass4), bounds AS (SELECT MIN(CAST("{COLUMN_NAME}" AS DATE)) AS min_date, MAX(CAST("{COLUMN_NAME}" AS DATE)) AS max_date FROM "{TARGET_SCHEMA}"."{TABLE_NAME}"), daterange AS (SELECT ADD_DAYS(b.min_date, n.rn) AS all_dates FROM bounds b, nums n WHERE ADD_DAYS(b.min_date, n.rn) <= b.max_date), existing_periods AS (SELECT DISTINCT CAST("{COLUMN_NAME}" AS DATE) AS period, COUNT(1) AS period_count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY CAST("{COLUMN_NAME}" AS DATE)) SELECT p.missing_period, p.prior_available_date, e.period_count AS prior_available_date_count, p.next_available_date, f.period_count AS next_available_date_count FROM (SELECT d.all_dates AS missing_period, MAX(b.period) AS prior_available_date, MIN(c.period) AS next_available_date FROM daterange d LEFT JOIN existing_periods a ON d.all_dates = a.period LEFT JOIN existing_periods b ON b.period < d.all_dates LEFT JOIN existing_periods c ON c.period > d.all_dates WHERE a.period IS NULL AND d.all_dates BETWEEN b.period AND c.period GROUP BY d.all_dates) p LEFT JOIN existing_periods e ON (p.prior_available_date = e.period) LEFT JOIN existing_periods f ON (p.next_available_date = f.period) ORDER BY p.missing_period LIMIT {LIMIT} error_type: Test Results + - id: '10010' + test_id: 1009 + test_type: Daily_Record_Ct + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + WITH RECURSIVE daterange(all_dates) AS (SELECT CAST(MIN("{COLUMN_NAME}") AS DATE) AS all_dates FROM "{TABLE_NAME}" UNION ALL SELECT CAST((d.all_dates + INTERVAL '1 day') AS DATE) AS all_dates FROM daterange d WHERE d.all_dates < (SELECT CAST(MAX("{COLUMN_NAME}") AS DATE) FROM "{TABLE_NAME}") ), existing_periods AS ( SELECT DISTINCT CAST("{COLUMN_NAME}" AS DATE) AS period, COUNT(1) AS period_count FROM "{TABLE_NAME}" GROUP BY CAST("{COLUMN_NAME}" AS DATE) ) SELECT p.missing_period, p.prior_available_date, e.period_count AS prior_available_date_count, p.next_available_date, f.period_count AS next_available_date_count FROM (SELECT d.all_dates AS missing_period, MAX(b.period) AS prior_available_date, MIN(c.period) AS next_available_date FROM daterange d LEFT JOIN existing_periods a ON d.all_dates = a.period LEFT JOIN existing_periods b ON b.period < d.all_dates LEFT JOIN existing_periods c ON c.period > d.all_dates WHERE a.period IS NULL AND d.all_dates BETWEEN b.period AND c.period GROUP BY d.all_dates) p LEFT JOIN existing_periods e ON (p.prior_available_date = e.period) LEFT JOIN existing_periods f ON (p.next_available_date = f.period) ORDER BY p.missing_period LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Dec_Trunc.yaml b/testgen/template/dbsetup_test_types/test_types_Dec_Trunc.yaml index e717d8fb..770d1175 100644 --- a/testgen/template/dbsetup_test_types/test_types_Dec_Trunc.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Dec_Trunc.yaml @@ -118,6 +118,14 @@ test_types: test_operator: < test_condition: |- {THRESHOLD_VALUE} + - id: '10006' + test_type: Dec_Trunc + sql_flavor: salesforce_data360 + measure: |- + SUM(ROUND(ABS(MOD({COLUMN_NAME}, 1)), 5))+1 + test_operator: '<' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1369' test_id: '1011' @@ -199,4 +207,12 @@ test_types: lookup_query: |- SELECT DISTINCT CASE WHEN LOCATE(TO_VARCHAR("{COLUMN_NAME}"), '.') > 0 THEN LENGTH(SUBSTR(TO_VARCHAR("{COLUMN_NAME}"), LOCATE(TO_VARCHAR("{COLUMN_NAME}"), '.') + 1)) ELSE 0 END AS decimal_scale, COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY CASE WHEN LOCATE(TO_VARCHAR("{COLUMN_NAME}"), '.') > 0 THEN LENGTH(SUBSTR(TO_VARCHAR("{COLUMN_NAME}"), LOCATE(TO_VARCHAR("{COLUMN_NAME}"), '.') + 1)) ELSE 0 END LIMIT {LIMIT} error_type: Test Results + - id: '10011' + test_id: 1011 + test_type: Dec_Trunc + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT LENGTH(SPLIT_PART(CAST("{COLUMN_NAME}" AS TEXT), '.', 2)) AS decimal_scale, COUNT(*) AS count FROM "{TABLE_NAME}" GROUP BY decimal_scale LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Distinct_Date_Ct.yaml b/testgen/template/dbsetup_test_types/test_types_Distinct_Date_Ct.yaml index 4ddc1dd4..23967398 100644 --- a/testgen/template/dbsetup_test_types/test_types_Distinct_Date_Ct.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Distinct_Date_Ct.yaml @@ -118,6 +118,14 @@ test_types: test_operator: < test_condition: |- {THRESHOLD_VALUE} + - id: '10007' + test_type: Distinct_Date_Ct + sql_flavor: salesforce_data360 + measure: |- + COUNT(DISTINCT {COLUMN_NAME}) + test_operator: '<' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1370' test_id: '1012' @@ -196,4 +204,12 @@ test_types: lookup_query: |- SELECT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE "{COLUMN_NAME}" IS NOT NULL GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC LIMIT {LIMIT} error_type: Test Results + - id: '10012' + test_id: 1012 + test_type: Distinct_Date_Ct + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE "{COLUMN_NAME}" IS NOT NULL GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Distinct_Value_Ct.yaml b/testgen/template/dbsetup_test_types/test_types_Distinct_Value_Ct.yaml index e7737220..47e609ec 100644 --- a/testgen/template/dbsetup_test_types/test_types_Distinct_Value_Ct.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Distinct_Value_Ct.yaml @@ -117,6 +117,14 @@ test_types: test_operator: <> test_condition: |- {THRESHOLD_VALUE} + - id: '10008' + test_type: Distinct_Value_Ct + sql_flavor: salesforce_data360 + measure: |- + COUNT(DISTINCT {COLUMN_NAME}) + test_operator: '<>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1371' test_id: '1013' @@ -195,4 +203,12 @@ test_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE "{COLUMN_NAME}" IS NOT NULL GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC LIMIT {LIMIT} error_type: Test Results + - id: '10013' + test_id: 1013 + test_type: Distinct_Value_Ct + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE "{COLUMN_NAME}" IS NOT NULL GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Distribution_Shift.yaml b/testgen/template/dbsetup_test_types/test_types_Distribution_Shift.yaml index 627cd8a3..7a3b2361 100644 --- a/testgen/template/dbsetup_test_types/test_types_Distribution_Shift.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Distribution_Shift.yaml @@ -74,7 +74,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, COUNT(*)::FLOAT / SUM(COUNT(*)) OVER ()::FLOAT AS pct_of_total - FROM {MATCH_SCHEMA_NAME}.{TABLE_NAME} v2 + FROM `{MATCH_SCHEMA_NAME}`.`{MATCH_TABLE_NAME}` v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ) SELECT COALESCE(l.category, o.category) AS category, @@ -102,7 +102,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, CAST(COUNT(*) as FLOAT) / CAST(SUM(COUNT(*)) OVER () as FLOAT) AS pct_of_total - FROM {MATCH_SCHEMA_NAME}.{TABLE_NAME} v2 + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ) SELECT TOP {LIMIT} COALESCE(l.category, o.category) AS category, @@ -129,7 +129,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, COUNT(*)::FLOAT / SUM(COUNT(*)) OVER ()::FLOAT AS pct_of_total - FROM {MATCH_SCHEMA_NAME}.{TABLE_NAME} v2 + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ) SELECT COALESCE(l.category, o.category) AS category, @@ -157,7 +157,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, COUNT(*)::FLOAT / SUM(COUNT(*)) OVER ()::FLOAT AS pct_of_total - FROM {MATCH_SCHEMA_NAME}.{TABLE_NAME} v2 + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ) SELECT COALESCE(l.category, o.category) AS category, @@ -185,7 +185,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, COUNT(*)::FLOAT / SUM(COUNT(*)) OVER ()::FLOAT AS pct_of_total - FROM {MATCH_SCHEMA_NAME}.{TABLE_NAME} v2 + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ) SELECT COALESCE(l.category, o.category) AS category, @@ -213,7 +213,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, COUNT(*)::FLOAT / SUM(COUNT(*)) OVER ()::FLOAT AS pct_of_total - FROM {MATCH_SCHEMA_NAME}.{TABLE_NAME} v2 + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ) SELECT COALESCE(l.category, o.category) AS category, @@ -241,7 +241,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, CAST(COUNT(*) AS NUMBER) / CAST(SUM(COUNT(*)) OVER () AS NUMBER) AS pct_of_total - FROM {MATCH_SCHEMA_NAME}.{TABLE_NAME} v2 + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ) SELECT COALESCE(l.category, o.category) AS category, @@ -269,7 +269,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, CAST(COUNT(*) AS DECIMAL) / CAST(SUM(COUNT(*)) OVER () AS DECIMAL) AS pct_of_total - FROM {MATCH_SCHEMA_NAME}.{TABLE_NAME} v2 + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ) SELECT COALESCE(l.category, o.category) AS category, @@ -281,6 +281,33 @@ test_types: ORDER BY COALESCE(l.category, o.category) LIMIT {LIMIT} error_type: Test Results + - id: '10014' + test_id: 1503 + test_type: Distribution_Shift + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + WITH latest_ver + AS ( SELECT {CONCAT_COLUMNS} as category, + CAST(COUNT(*) AS DOUBLE) / CAST(SUM(COUNT(*)) OVER () AS DOUBLE) AS pct_of_total + FROM "{TABLE_NAME}" v1 + WHERE {SUBSET_CONDITION} + GROUP BY {COLUMN_NAME_NO_QUOTES} ), + older_ver + AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, + CAST(COUNT(*) AS DOUBLE) / CAST(SUM(COUNT(*)) OVER () AS DOUBLE) AS pct_of_total + FROM "{MATCH_TABLE_NAME}" v2 + WHERE {MATCH_SUBSET_CONDITION} + GROUP BY {MATCH_GROUPBY_NAMES} ) + SELECT COALESCE(l.category, o.category) AS category, + o.pct_of_total AS old_pct, + l.pct_of_total AS new_pct + FROM latest_ver l + FULL JOIN older_ver o + ON (l.category = o.category) + ORDER BY COALESCE(l.category, o.category) + LIMIT {LIMIT}; + error_type: Test Results test_templates: - id: '2503' test_type: Distribution_Shift @@ -302,7 +329,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} AS category, CAST(COUNT(*) AS FLOAT64) / CAST(SUM(COUNT(*)) OVER () AS FLOAT64) AS pct_of_total - FROM `{MATCH_SCHEMA_NAME}.{TABLE_NAME}` v2 + FROM `{MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME}` v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ), @@ -355,7 +382,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, COUNT(*)::FLOAT / SUM(COUNT(*)) OVER ()::FLOAT AS pct_of_total - FROM {QUOTE}{MATCH_SCHEMA_NAME}{QUOTE}.{QUOTE}{TABLE_NAME}{QUOTE} v2 + FROM {QUOTE}{MATCH_SCHEMA_NAME}{QUOTE}.{QUOTE}{MATCH_TABLE_NAME}{QUOTE} v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ), dataset @@ -408,7 +435,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, CAST(COUNT(*) as FLOAT) / CAST(SUM(COUNT(*)) OVER () as FLOAT) AS pct_of_total - FROM "{MATCH_SCHEMA_NAME}"."{TABLE_NAME}" v2 + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ), dataset @@ -461,7 +488,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, COUNT(*)::FLOAT / SUM(COUNT(*)) OVER ()::FLOAT AS pct_of_total - FROM {QUOTE}{MATCH_SCHEMA_NAME}{QUOTE}.{QUOTE}{TABLE_NAME}{QUOTE} v2 + FROM {QUOTE}{MATCH_SCHEMA_NAME}{QUOTE}.{QUOTE}{MATCH_TABLE_NAME}{QUOTE} v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ), dataset @@ -514,7 +541,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, COUNT(*)::FLOAT / SUM(COUNT(*)) OVER ()::FLOAT AS pct_of_total - FROM {QUOTE}{MATCH_SCHEMA_NAME}{QUOTE}.{QUOTE}{TABLE_NAME}{QUOTE} v2 + FROM {QUOTE}{MATCH_SCHEMA_NAME}{QUOTE}.{QUOTE}{MATCH_TABLE_NAME}{QUOTE} v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ), dataset @@ -567,7 +594,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, COUNT(*)::FLOAT / SUM(COUNT(*)) OVER ()::FLOAT AS pct_of_total - FROM {QUOTE}{MATCH_SCHEMA_NAME}{QUOTE}.{QUOTE}{TABLE_NAME}{QUOTE} v2 + FROM {QUOTE}{MATCH_SCHEMA_NAME}{QUOTE}.{QUOTE}{MATCH_TABLE_NAME}{QUOTE} v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ), dataset @@ -620,7 +647,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, COUNT(*)::FLOAT / SUM(COUNT(*)) OVER ()::FLOAT AS pct_of_total - FROM {QUOTE}{MATCH_SCHEMA_NAME}{QUOTE}.{QUOTE}{TABLE_NAME}{QUOTE} v2 + FROM {QUOTE}{MATCH_SCHEMA_NAME}{QUOTE}.{QUOTE}{MATCH_TABLE_NAME}{QUOTE} v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ), dataset @@ -673,7 +700,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, CAST(COUNT(*) AS NUMBER) / CAST(SUM(COUNT(*)) OVER () AS NUMBER) AS pct_of_total - FROM {QUOTE}{MATCH_SCHEMA_NAME}{QUOTE}.{QUOTE}{TABLE_NAME}{QUOTE} v2 + FROM {QUOTE}{MATCH_SCHEMA_NAME}{QUOTE}.{QUOTE}{MATCH_TABLE_NAME}{QUOTE} v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ), dataset @@ -724,7 +751,7 @@ test_types: older_ver AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, CAST(COUNT(*) AS DECIMAL) / CAST(SUM(COUNT(*)) OVER () AS DECIMAL) AS pct_of_total - FROM {QUOTE}{MATCH_SCHEMA_NAME}{QUOTE}.{QUOTE}{TABLE_NAME}{QUOTE} v2 + FROM {QUOTE}{MATCH_SCHEMA_NAME}{QUOTE}.{QUOTE}{MATCH_TABLE_NAME}{QUOTE} v2 WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} ), dataset @@ -756,3 +783,56 @@ test_types: SELECT 0.5 * ABS(SUM(new_pct * LN(new_pct/avg_pct)/LN(2))) + 0.5 * ABS(SUM(old_pct * LN(old_pct/avg_pct)/LN(2))) as js_divergence FROM dataset ) rslt + - id: '10007' + test_type: Distribution_Shift + sql_flavor: salesforce_data360 + template: |- + -- Relative Entropy: measured by Jensen-Shannon Divergence + -- Smoothed and normalized version of KL divergence, + -- with scores between 0 (identical) and 1 (maximally different), + -- when using the base-2 logarithm. Formula is: + -- 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m) + -- Log base 2 of x = LN(x)/LN(2) + WITH latest_ver + AS ( SELECT {CONCAT_COLUMNS} as category, + CAST(COUNT(*) AS DOUBLE) / CAST(SUM(COUNT(*)) OVER () AS DOUBLE) AS pct_of_total + FROM {QUOTE}{TABLE_NAME}{QUOTE} v1 + WHERE {SUBSET_CONDITION} + GROUP BY {COLUMN_NAME_NO_QUOTES} ), + older_ver + AS ( SELECT {CONCAT_MATCH_GROUPBY} as category, + CAST(COUNT(*) AS DOUBLE) / CAST(SUM(COUNT(*)) OVER () AS DOUBLE) AS pct_of_total + FROM {QUOTE}{MATCH_TABLE_NAME}{QUOTE} v2 + WHERE {MATCH_SUBSET_CONDITION} + GROUP BY {MATCH_GROUPBY_NAMES} ), + dataset + AS ( SELECT COALESCE(l.category, o.category) AS category, + COALESCE(o.pct_of_total, 0.0000001) AS old_pct, + COALESCE(l.pct_of_total, 0.0000001) AS new_pct, + (COALESCE(o.pct_of_total, 0.0000001) + + COALESCE(l.pct_of_total, 0.0000001))/2.0 AS avg_pct + FROM latest_ver l + FULL JOIN older_ver o + ON (l.category = o.category) ) + SELECT '{TEST_TYPE}' as test_type, + '{TEST_DEFINITION_ID}' as test_definition_id, + '{TEST_SUITE_ID}' as test_suite_id, + '{TEST_RUN_ID}' as test_run_id, + '{RUN_DATE}' as test_time, + '{SCHEMA_NAME}' as schema_name, + '{TABLE_NAME}' as table_name, + '{COLUMN_NAME_NO_QUOTES}' as column_names, + -- '{GROUPBY_NAMES}' as column_names, + '{THRESHOLD_VALUE}' as threshold_value, + NULL as skip_errors, + '{INPUT_PARAMETERS}' as input_parameters, + NULL as result_signal, + CASE WHEN js_divergence > {THRESHOLD_VALUE} THEN 0 ELSE 1 END as result_code, + CONCAT('Divergence Level: ', + CONCAT(CAST(js_divergence AS {VARCHAR_TYPE}), + ', Threshold: {THRESHOLD_VALUE}.')) as result_message, + js_divergence as result_measure + FROM ( + SELECT 0.5 * ABS(SUM(new_pct * LN(new_pct/avg_pct)/LN(2))) + + 0.5 * ABS(SUM(old_pct * LN(old_pct/avg_pct)/LN(2))) as js_divergence + FROM dataset ) rslt; diff --git a/testgen/template/dbsetup_test_types/test_types_Dupe_Rows.yaml b/testgen/template/dbsetup_test_types/test_types_Dupe_Rows.yaml index 57c778cc..1ef27125 100644 --- a/testgen/template/dbsetup_test_types/test_types_Dupe_Rows.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Dupe_Rows.yaml @@ -165,6 +165,20 @@ test_types: ORDER BY {GROUPBY_NAMES} LIMIT {LIMIT} error_type: Test Results + - id: '10015' + test_id: 1510 + test_type: Dupe_Rows + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT {GROUPBY_NAMES}, COUNT(*) as record_ct + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + GROUP BY {GROUPBY_NAMES} + HAVING COUNT(*) > 1 + ORDER BY {GROUPBY_NAMES} + LIMIT {LIMIT}; + error_type: Test Results test_templates: - id: '2511' test_type: Dupe_Rows @@ -499,3 +513,41 @@ test_types: GROUP BY {GROUPBY_NAMES} HAVING COUNT(*) > 1 ) test + - id: '10008' + test_type: Dupe_Rows + sql_flavor: salesforce_data360 + template: |- + SELECT '{TEST_TYPE}' as test_type, + '{TEST_DEFINITION_ID}' as test_definition_id, + '{TEST_SUITE_ID}' as test_suite_id, + '{TEST_RUN_ID}' as test_run_id, + '{RUN_DATE}' as test_time, + '{SCHEMA_NAME}' as schema_name, + '{TABLE_NAME}' as table_name, + '{COLUMN_NAME_NO_QUOTES}' as column_names, + '{SKIP_ERRORS}' as threshold_value, + {SKIP_ERRORS} as skip_errors, + '{INPUT_PARAMETERS}' as input_parameters, + NULL as result_signal, + CASE WHEN COUNT (*) > {SKIP_ERRORS} THEN 0 ELSE 1 END as result_code, + CASE + WHEN COUNT(*) > 0 THEN + CONCAT( + CONCAT( CAST(COUNT(*) AS {VARCHAR_TYPE}), ' duplicate row(s) identified, ' ), + CONCAT( + CASE + WHEN COUNT(*) > {SKIP_ERRORS} THEN 'exceeding limit of ' + ELSE 'within limit of ' + END, + '{SKIP_ERRORS}.' + ) + ) + ELSE 'No errors found.' + END AS result_message, + COALESCE(SUM(record_ct), 0) as result_measure + FROM ( SELECT {GROUPBY_NAMES}, COUNT(*) as record_ct + FROM {QUOTE}{TABLE_NAME}{QUOTE} + WHERE {SUBSET_CONDITION} + GROUP BY {GROUPBY_NAMES} + HAVING COUNT(*) > 1 + ) test; diff --git a/testgen/template/dbsetup_test_types/test_types_Email_Format.yaml b/testgen/template/dbsetup_test_types/test_types_Email_Format.yaml index ab0a8704..6d6573b4 100644 --- a/testgen/template/dbsetup_test_types/test_types_Email_Format.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Email_Format.yaml @@ -117,6 +117,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10009' + test_type: Email_Format + sql_flavor: salesforce_data360 + measure: |- + SUM(CASE WHEN NOT REGEXP_LIKE(CAST({COLUMN_NAME} AS VARCHAR), '^[A-Za-z0-9._''%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}$') THEN 1 ELSE 0 END) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1372' test_id: '1014' @@ -194,4 +202,12 @@ test_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE NOT "{COLUMN_NAME}" LIKE_REGEXPR '^[A-Za-z0-9._''%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}$' GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Test Results + - id: '10016' + test_id: 1014 + test_type: Email_Format + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE NOT REGEXP_LIKE(CAST("{COLUMN_NAME}" AS VARCHAR), '^[A-Za-z0-9._''%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}$') GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Freshness_Trend.yaml b/testgen/template/dbsetup_test_types/test_types_Freshness_Trend.yaml index ba60e7c5..c8d1e3a2 100644 --- a/testgen/template/dbsetup_test_types/test_types_Freshness_Trend.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Freshness_Trend.yaml @@ -492,3 +492,53 @@ test_types: ELSE COALESCE(TO_VARCHAR(interval_minutes), 'Unknown') END AS result_signal FROM test_data; + - id: '10009' + test_type: Freshness_Trend + sql_flavor: salesforce_data360 + template: |- + WITH test_data AS ( + SELECT + MD5({CUSTOM_QUERY}) AS fingerprint, + DATEDIFF('minute', CAST(NULLIF('{BASELINE_SUM}', '') AS TIMESTAMP), CAST('{RUN_DATE}' AS TIMESTAMP)) AS interval_minutes + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + ) + SELECT '{TEST_TYPE}' AS test_type, + '{TEST_DEFINITION_ID}' AS test_definition_id, + '{TEST_SUITE_ID}' AS test_suite_id, + '{TEST_RUN_ID}' AS test_run_id, + '{RUN_DATE}' AS test_time, + '{SCHEMA_NAME}' AS schema_name, + '{TABLE_NAME}' AS table_name, + '{COLUMN_NAME_NO_QUOTES}' AS column_names, + '{SKIP_ERRORS}' AS threshold_value, + {SKIP_ERRORS} AS skip_errors, + '{INPUT_PARAMETERS}' AS input_parameters, + fingerprint AS result_measure, + CASE + -- Training mode: tolerances not yet calculated + WHEN {LOWER_TOLERANCE} IS NULL AND {UPPER_TOLERANCE} IS NULL THEN -1 + -- No change and excluded day: suppress + WHEN fingerprint = '{BASELINE_VALUE}' AND {IS_EXCLUDED_DAY} = 1 THEN 1 + -- No change, beyond time range (business time): LATE + WHEN fingerprint = '{BASELINE_VALUE}' + AND (interval_minutes - {EXCLUDED_MINUTES}) > {THRESHOLD_VALUE} THEN 0 + -- Table changed outside time range (business time): UNEXPECTED + WHEN fingerprint <> '{BASELINE_VALUE}' + AND NOT (interval_minutes - {EXCLUDED_MINUTES}) + BETWEEN {LOWER_TOLERANCE} AND {UPPER_TOLERANCE} THEN 0 + ELSE 1 + END AS result_code, + 'Table update detected: ' || CASE WHEN fingerprint <> '{BASELINE_VALUE}' THEN 'Yes' ELSE 'No' END + || CASE + WHEN fingerprint <> '{BASELINE_VALUE}' AND (interval_minutes - {EXCLUDED_MINUTES}) BETWEEN {LOWER_TOLERANCE} AND {UPPER_TOLERANCE} THEN '. On time.' + WHEN fingerprint <> '{BASELINE_VALUE}' AND (interval_minutes - {EXCLUDED_MINUTES}) < {LOWER_TOLERANCE} THEN '. Earlier than expected.' + WHEN fingerprint <> '{BASELINE_VALUE}' AND (interval_minutes - {EXCLUDED_MINUTES}) > {UPPER_TOLERANCE} THEN '. Later than expected.' + WHEN fingerprint = '{BASELINE_VALUE}' AND {IS_EXCLUDED_DAY} = 0 AND (interval_minutes - {EXCLUDED_MINUTES}) > {THRESHOLD_VALUE} THEN '. Late.' + ELSE '' + END AS result_message, + CASE + WHEN fingerprint <> '{BASELINE_VALUE}' THEN '0' + ELSE COALESCE(CAST(interval_minutes AS VARCHAR), 'Unknown') + END AS result_signal + FROM test_data; diff --git a/testgen/template/dbsetup_test_types/test_types_Future_Date.yaml b/testgen/template/dbsetup_test_types/test_types_Future_Date.yaml index 938091da..c2327843 100644 --- a/testgen/template/dbsetup_test_types/test_types_Future_Date.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Future_Date.yaml @@ -116,6 +116,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10010' + test_type: Future_Date + sql_flavor: salesforce_data360 + measure: |- + SUM(GREATEST(0, SIGN(CAST({COLUMN_NAME} AS DATE) - CAST('{RUN_DATE}' AS DATE)))) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1373' test_id: '1015' @@ -193,4 +201,12 @@ test_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE CAST("{COLUMN_NAME}" AS DATE) > TO_DATE('{TEST_DATE}', 'YYYY-MM-DD HH24:MI:SS') GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Test Results + - id: '10017' + test_id: 1015 + test_type: Future_Date + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE GREATEST(0, SIGN(CAST("{COLUMN_NAME}" AS DATE) - CAST('{TEST_DATE}' AS DATE))) > {THRESHOLD_VALUE} GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Future_Date_1Y.yaml b/testgen/template/dbsetup_test_types/test_types_Future_Date_1Y.yaml index 01a42a83..c1ec7d6d 100644 --- a/testgen/template/dbsetup_test_types/test_types_Future_Date_1Y.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Future_Date_1Y.yaml @@ -117,6 +117,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10011' + test_type: Future_Date_1Y + sql_flavor: salesforce_data360 + measure: |- + SUM(GREATEST(0, SIGN(CAST({COLUMN_NAME} AS DATE) - (CAST('{RUN_DATE}' AS DATE)+365)))) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1374' test_id: '1016' @@ -194,4 +202,12 @@ test_types: lookup_query: |- SELECT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE CAST("{COLUMN_NAME}" AS DATE) > ADD_DAYS(TO_DATE('{TEST_DATE}', 'YYYY-MM-DD HH24:MI:SS'), 365) GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Test Results + - id: '10018' + test_id: 1016 + test_type: Future_Date_1Y + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE GREATEST(0, SIGN(CAST("{COLUMN_NAME}" AS DATE) - (CAST('{TEST_DATE}' AS DATE) + 365))) > {THRESHOLD_VALUE} GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Incr_Avg_Shift.yaml b/testgen/template/dbsetup_test_types/test_types_Incr_Avg_Shift.yaml index eddb6227..00b6b9f6 100644 --- a/testgen/template/dbsetup_test_types/test_types_Incr_Avg_Shift.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Incr_Avg_Shift.yaml @@ -118,6 +118,14 @@ test_types: test_operator: '>=' test_condition: |- {THRESHOLD_VALUE} + - id: '10012' + test_type: Incr_Avg_Shift + sql_flavor: salesforce_data360 + measure: |- + COALESCE(ABS( ({BASELINE_AVG} - (SUM({COLUMN_NAME}) - {BASELINE_SUM}) / NULLIF(CAST(COUNT({COLUMN_NAME}) AS FLOAT) - {BASELINE_VALUE_CT}, 0)) / {BASELINE_SD} ), 0) + test_operator: '>=' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1375' test_id: '1017' @@ -194,4 +202,12 @@ test_types: lookup_query: |- SELECT AVG(CAST("{COLUMN_NAME}" AS DECIMAL)) AS current_average, SUM(CAST("{COLUMN_NAME}" AS DECIMAL)) AS current_sum, NULLIF(COUNT("{COLUMN_NAME}"), 0) as current_value_count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" error_type: Test Results + - id: '10019' + test_id: 1017 + test_type: Incr_Avg_Shift + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT AVG(CAST("{COLUMN_NAME}" AS FLOAT)) AS current_average, SUM(CAST("{COLUMN_NAME}" AS FLOAT)) AS current_sum, NULLIF(CAST(COUNT("{COLUMN_NAME}" ) AS FLOAT), 0) as current_value_count FROM "{TABLE_NAME}"; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_LOV_All.yaml b/testgen/template/dbsetup_test_types/test_types_LOV_All.yaml index 2cf10836..59a07d79 100644 --- a/testgen/template/dbsetup_test_types/test_types_LOV_All.yaml +++ b/testgen/template/dbsetup_test_types/test_types_LOV_All.yaml @@ -115,6 +115,14 @@ test_types: test_operator: <> test_condition: |- {THRESHOLD_VALUE} + - id: '10013' + test_type: LOV_All + sql_flavor: salesforce_data360 + measure: |- + (SELECT ARRAY_JOIN(ARRAY_AGG(sub_val), '|') FROM (SELECT DISTINCT {COLUMN_NAME} AS sub_val FROM "{TABLE_NAME}" WHERE {SUBSET_CONDITION} ORDER BY 1 LIMIT 1000) sub_lov) + test_operator: '<>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1376' test_id: '1018' @@ -203,4 +211,12 @@ test_types: lookup_query: |- SELECT STRING_AGG("{COLUMN_NAME}", '|' ORDER BY "{COLUMN_NAME}") AS lov FROM (SELECT DISTINCT "{COLUMN_NAME}" FROM "{TARGET_SCHEMA}"."{TABLE_NAME}") HAVING STRING_AGG("{COLUMN_NAME}", '|' ORDER BY "{COLUMN_NAME}") <> {THRESHOLD_VALUE} LIMIT {LIMIT} error_type: Test Results + - id: '10020' + test_id: 1018 + test_type: LOV_All + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT ARRAY_JOIN(ARRAY_AGG(DISTINCT "{COLUMN_NAME}"), '|') AS lov FROM "{TABLE_NAME}" HAVING ARRAY_JOIN(ARRAY_AGG(DISTINCT "{COLUMN_NAME}"), '|') <> {THRESHOLD_VALUE} LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_LOV_Match.yaml b/testgen/template/dbsetup_test_types/test_types_LOV_Match.yaml index 768dd65b..38b8040c 100644 --- a/testgen/template/dbsetup_test_types/test_types_LOV_Match.yaml +++ b/testgen/template/dbsetup_test_types/test_types_LOV_Match.yaml @@ -221,6 +221,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10014' + test_type: LOV_Match + sql_flavor: salesforce_data360 + measure: |- + SUM(CASE WHEN NULLIF({COLUMN_NAME}, '') NOT IN {BASELINE_VALUE} THEN 1 ELSE 0 END) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1377' test_id: '1019' @@ -298,4 +306,12 @@ test_types: lookup_query: |- SELECT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE "{COLUMN_NAME}" IS NOT NULL AND "{COLUMN_NAME}" NOT IN {BASELINE_VALUE} GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Test Results + - id: '10021' + test_id: 1019 + test_type: LOV_Match + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT NULLIF("{COLUMN_NAME}", '') AS "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE NULLIF("{COLUMN_NAME}", '') NOT IN {BASELINE_VALUE} GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Metric_Trend.yaml b/testgen/template/dbsetup_test_types/test_types_Metric_Trend.yaml index 31e17846..7675e9d1 100644 --- a/testgen/template/dbsetup_test_types/test_types_Metric_Trend.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Metric_Trend.yaml @@ -106,6 +106,14 @@ test_types: test_operator: NOT BETWEEN test_condition: |- {LOWER_TOLERANCE} AND {UPPER_TOLERANCE} + - id: '10015' + test_type: Metric_Trend + sql_flavor: salesforce_data360 + measure: |- + {CUSTOM_QUERY} + test_operator: NOT BETWEEN + test_condition: |- + {LOWER_TOLERANCE} AND {UPPER_TOLERANCE} target_data_lookups: - id: '1484' test_id: '1514' @@ -206,4 +214,15 @@ test_types: {UPPER_TOLERANCE} AS upper_bound FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" error_type: Test Results + - id: '10022' + test_id: 1514 + test_type: Metric_Trend + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT {CUSTOM_QUERY} AS current_count, + {LOWER_TOLERANCE} AS lower_bound, + {UPPER_TOLERANCE} AS upper_bound + FROM "{TABLE_NAME}"; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Min_Date.yaml b/testgen/template/dbsetup_test_types/test_types_Min_Date.yaml index a2762969..a3f9cbe1 100644 --- a/testgen/template/dbsetup_test_types/test_types_Min_Date.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Min_Date.yaml @@ -117,6 +117,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10016' + test_type: Min_Date + sql_flavor: salesforce_data360 + measure: |- + SUM(CASE WHEN {COLUMN_NAME} < '{BASELINE_VALUE}' THEN 1 ELSE 0 END) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1378' test_id: '1020' @@ -194,4 +202,12 @@ test_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE "{COLUMN_NAME}" < CAST('{BASELINE_VALUE}' AS {COLUMN_TYPE}) GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Test Results + - id: '10023' + test_id: 1020 + test_type: Min_Date + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE CAST("{COLUMN_NAME}" AS DATE) < CAST('{BASELINE_VALUE}' AS DATE) GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Min_Val.yaml b/testgen/template/dbsetup_test_types/test_types_Min_Val.yaml index 3a852155..90f107f4 100644 --- a/testgen/template/dbsetup_test_types/test_types_Min_Val.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Min_Val.yaml @@ -117,6 +117,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10017' + test_type: Min_Val + sql_flavor: salesforce_data360 + measure: |- + SUM(CASE WHEN {COLUMN_NAME} < {BASELINE_VALUE} - 1e-6 THEN 1 ELSE 0 END) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1379' test_id: '1021' @@ -193,4 +201,12 @@ test_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", (ABS("{COLUMN_NAME}") - ABS({BASELINE_VALUE})) AS difference_from_baseline FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE "{COLUMN_NAME}" < {BASELINE_VALUE} LIMIT {LIMIT} error_type: Test Results + - id: '10024' + test_id: 1021 + test_type: Min_Val + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", (ABS("{COLUMN_NAME}") - ABS({BASELINE_VALUE})) AS difference_from_baseline FROM "{TABLE_NAME}" WHERE "{COLUMN_NAME}" < {BASELINE_VALUE} LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Missing_Pct.yaml b/testgen/template/dbsetup_test_types/test_types_Missing_Pct.yaml index d85d0908..f4dd0a0a 100644 --- a/testgen/template/dbsetup_test_types/test_types_Missing_Pct.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Missing_Pct.yaml @@ -118,6 +118,14 @@ test_types: test_operator: '>=' test_condition: |- {THRESHOLD_VALUE} + - id: '10018' + test_type: Missing_Pct + sql_flavor: salesforce_data360 + measure: |- + ABS( 2.0 * ASIN( SQRT( CAST({BASELINE_VALUE_CT} AS FLOAT) / CAST({BASELINE_CT} AS FLOAT) ) ) - 2 * ASIN( SQRT( CAST(COUNT( {COLUMN_NAME} ) AS FLOAT) / CAST(NULLIF(COUNT(*), 0) AS FLOAT) )) ) + test_operator: '>=' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1380' test_id: '1022' @@ -194,4 +202,12 @@ test_types: lookup_query: |- SELECT * FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE "{COLUMN_NAME}" IS NULL LIMIT {LIMIT} error_type: Test Results + - id: '10025' + test_id: 1022 + test_type: Missing_Pct + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT * FROM "{TABLE_NAME}" WHERE "{COLUMN_NAME}" IS NULL OR CAST("{COLUMN_NAME}" AS VARCHAR(255)) = '' LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Monthly_Rec_Ct.yaml b/testgen/template/dbsetup_test_types/test_types_Monthly_Rec_Ct.yaml index 8fd1fcdb..35580b34 100644 --- a/testgen/template/dbsetup_test_types/test_types_Monthly_Rec_Ct.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Monthly_Rec_Ct.yaml @@ -118,6 +118,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10019' + test_type: Monthly_Rec_Ct + sql_flavor: salesforce_data360 + measure: |- + (MAX(DATEDIFF('month', CAST({COLUMN_NAME} AS DATE), CAST('{RUN_DATE}' AS DATE))) - MIN(DATEDIFF('month', CAST({COLUMN_NAME} AS DATE), CAST('{RUN_DATE}' AS DATE))) + 1) - COUNT(DISTINCT DATEDIFF('month', CAST({COLUMN_NAME} AS DATE), CAST('{RUN_DATE}' AS DATE))) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1381' test_id: '1023' @@ -259,4 +267,12 @@ test_types: lookup_query: |- WITH Pass0 AS (SELECT 1 C FROM DUMMY UNION ALL SELECT 1 FROM DUMMY), Pass1 AS (SELECT 1 C FROM Pass0 A, Pass0 B), Pass2 AS (SELECT 1 C FROM Pass1 A, Pass1 B), Pass3 AS (SELECT 1 C FROM Pass2 A, Pass2 B), nums AS (SELECT ROW_NUMBER() OVER (ORDER BY C) - 1 AS rn FROM Pass3), bounds AS (SELECT TO_DATE(YEAR(MIN("{COLUMN_NAME}")) || '-' || LPAD(MONTH(MIN("{COLUMN_NAME}")), 2, '0') || '-01', 'YYYY-MM-DD') AS min_month, TO_DATE(YEAR(MAX("{COLUMN_NAME}")) || '-' || LPAD(MONTH(MAX("{COLUMN_NAME}")), 2, '0') || '-01', 'YYYY-MM-DD') AS max_month FROM "{TARGET_SCHEMA}"."{TABLE_NAME}"), daterange AS (SELECT ADD_MONTHS(b.min_month, n.rn) AS all_dates FROM bounds b, nums n WHERE ADD_MONTHS(b.min_month, n.rn) <= b.max_month), existing_periods AS (SELECT DISTINCT TO_DATE(YEAR("{COLUMN_NAME}") || '-' || LPAD(MONTH("{COLUMN_NAME}"), 2, '0') || '-01', 'YYYY-MM-DD') AS period, COUNT(1) AS period_count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY YEAR("{COLUMN_NAME}"), MONTH("{COLUMN_NAME}")) SELECT p.missing_period, p.prior_available_month, e.period_count AS prior_available_month_count, p.next_available_month, f.period_count AS next_available_month_count FROM (SELECT d.all_dates AS missing_period, MAX(b.period) AS prior_available_month, MIN(c.period) AS next_available_month FROM daterange d LEFT JOIN existing_periods a ON d.all_dates = a.period LEFT JOIN existing_periods b ON b.period < d.all_dates LEFT JOIN existing_periods c ON c.period > d.all_dates WHERE a.period IS NULL AND d.all_dates BETWEEN b.period AND c.period GROUP BY d.all_dates) p LEFT JOIN existing_periods e ON (p.prior_available_month = e.period) LEFT JOIN existing_periods f ON (p.next_available_month = f.period) ORDER BY p.missing_period LIMIT {LIMIT} error_type: Test Results + - id: '10026' + test_id: 1023 + test_type: Monthly_Rec_Ct + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + WITH RECURSIVE daterange(all_dates) AS (SELECT CAST(DATE_TRUNC('month', MIN("{COLUMN_NAME}")) AS DATE) AS all_dates FROM "{TABLE_NAME}" UNION ALL SELECT CAST((d.all_dates + INTERVAL '1 month') AS DATE) AS all_dates FROM daterange d WHERE d.all_dates < (SELECT CAST(DATE_TRUNC('month', MAX("{COLUMN_NAME}")) AS DATE) FROM "{TABLE_NAME}") ), existing_periods AS ( SELECT DISTINCT CAST(DATE_TRUNC('month',"{COLUMN_NAME}") AS DATE) AS period, COUNT(1) AS period_count FROM "{TABLE_NAME}" GROUP BY CAST(DATE_TRUNC('month',"{COLUMN_NAME}") AS DATE) ) SELECT p.missing_period, p.prior_available_month, e.period_count AS prior_available_month_count, p.next_available_month, f.period_count AS next_available_month_count FROM (SELECT d.all_dates AS missing_period, MAX(b.period) AS prior_available_month, MIN(c.period) AS next_available_month FROM daterange d LEFT JOIN existing_periods a ON d.all_dates = a.period LEFT JOIN existing_periods b ON b.period < d.all_dates LEFT JOIN existing_periods c ON c.period > d.all_dates WHERE a.period IS NULL AND d.all_dates BETWEEN b.period AND c.period GROUP BY d.all_dates) p LEFT JOIN existing_periods e ON (p.prior_available_month = e.period) LEFT JOIN existing_periods f ON (p.next_available_month = f.period) ORDER BY p.missing_period LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Outlier_Pct_Above.yaml b/testgen/template/dbsetup_test_types/test_types_Outlier_Pct_Above.yaml index 6b26ccb1..5cf1493f 100644 --- a/testgen/template/dbsetup_test_types/test_types_Outlier_Pct_Above.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Outlier_Pct_Above.yaml @@ -122,6 +122,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10020' + test_type: Outlier_Pct_Above + sql_flavor: salesforce_data360 + measure: |- + CAST(SUM(CASE WHEN CAST({COLUMN_NAME} AS FLOAT) > {BASELINE_AVG}+(2.0*{BASELINE_SD}) THEN 1 ELSE 0 END) AS FLOAT) / CAST(NULLIF(COUNT({COLUMN_NAME}), 0) AS FLOAT) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1382' test_id: '1024' @@ -199,4 +207,12 @@ test_types: lookup_query: |- SELECT ({BASELINE_AVG} + (2*{BASELINE_SD})) AS outlier_threshold, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE CAST("{COLUMN_NAME}" AS DECIMAL) > ({BASELINE_AVG} + (2*{BASELINE_SD})) GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC error_type: Test Results + - id: '10027' + test_id: 1024 + test_type: Outlier_Pct_Above + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT ({BASELINE_AVG} + (2*{BASELINE_SD})) AS outlier_threshold, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE CAST("{COLUMN_NAME}" AS FLOAT) > ({BASELINE_AVG} + (2*{BASELINE_SD})) GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Outlier_Pct_Below.yaml b/testgen/template/dbsetup_test_types/test_types_Outlier_Pct_Below.yaml index a2354e6e..fa88ab8c 100644 --- a/testgen/template/dbsetup_test_types/test_types_Outlier_Pct_Below.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Outlier_Pct_Below.yaml @@ -122,6 +122,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10021' + test_type: Outlier_Pct_Below + sql_flavor: salesforce_data360 + measure: |- + CAST(SUM(CASE WHEN CAST({COLUMN_NAME} AS FLOAT) < {BASELINE_AVG}-(2.0*{BASELINE_SD}) THEN 1 ELSE 0 END) AS FLOAT) / CAST(NULLIF(COUNT({COLUMN_NAME}), 0) AS FLOAT) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1383' test_id: '1025' @@ -199,4 +207,12 @@ test_types: lookup_query: |- SELECT ({BASELINE_AVG} - (2*{BASELINE_SD})) AS outlier_threshold, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE CAST("{COLUMN_NAME}" AS DECIMAL) < ({BASELINE_AVG} - (2*{BASELINE_SD})) GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC error_type: Test Results + - id: '10028' + test_id: 1025 + test_type: Outlier_Pct_Below + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT ({BASELINE_AVG} - (2*{BASELINE_SD})) AS outlier_threshold, "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE CAST("{COLUMN_NAME}" AS FLOAT) < ({BASELINE_AVG} - (2*{BASELINE_SD})) GROUP BY "{COLUMN_NAME}" ORDER BY "{COLUMN_NAME}" DESC; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Pattern_Match.yaml b/testgen/template/dbsetup_test_types/test_types_Pattern_Match.yaml index b3d0862f..6998da47 100644 --- a/testgen/template/dbsetup_test_types/test_types_Pattern_Match.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Pattern_Match.yaml @@ -117,6 +117,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10022' + test_type: Pattern_Match + sql_flavor: salesforce_data360 + measure: |- + COUNT(NULLIF({COLUMN_NAME}, '')) - SUM(CASE WHEN REGEXP_LIKE(CAST(NULLIF({COLUMN_NAME}, '') AS VARCHAR), '{BASELINE_VALUE}') THEN 1 ELSE 0 END) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1384' test_id: '1026' @@ -194,4 +202,12 @@ test_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE NOT NULLIF(TO_VARCHAR("{COLUMN_NAME}"), '') LIKE_REGEXPR '{BASELINE_VALUE}' GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Test Results + - id: '10029' + test_id: 1026 + test_type: Pattern_Match + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE NOT REGEXP_LIKE(CAST(NULLIF("{COLUMN_NAME}", '') AS VARCHAR), '{BASELINE_VALUE}') GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Recency.yaml b/testgen/template/dbsetup_test_types/test_types_Recency.yaml index 088a3a92..34945e9b 100644 --- a/testgen/template/dbsetup_test_types/test_types_Recency.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Recency.yaml @@ -118,6 +118,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10023' + test_type: Recency + sql_flavor: salesforce_data360 + measure: |- + DATEDIFF('day', CAST(MAX({COLUMN_NAME}) AS DATE), CAST('{RUN_DATE}' AS DATE)) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1385' test_id: '1028' @@ -203,4 +211,12 @@ test_types: lookup_query: |- SELECT DISTINCT col AS latest_date_available, TO_DATE('{TEST_DATE}', 'YYYY-MM-DD HH24:MI:SS') AS test_run_date FROM (SELECT MAX("{COLUMN_NAME}") AS col FROM "{TARGET_SCHEMA}"."{TABLE_NAME}") WHERE <%DATEDIFF_DAY;col;TO_DATE('{TEST_DATE}', 'YYYY-MM-DD HH24:MI:SS')%> > {THRESHOLD_VALUE} LIMIT {LIMIT} error_type: Test Results + - id: '10030' + test_id: 1028 + test_type: Recency + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT col AS latest_date_available, CAST('{TEST_DATE}' AS DATE) as test_run_date FROM (SELECT MAX("{COLUMN_NAME}") AS col FROM "{TABLE_NAME}") a WHERE DATEDIFF('day', CAST(col AS DATE), CAST('{TEST_DATE}' AS DATE)) > {THRESHOLD_VALUE} LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Required.yaml b/testgen/template/dbsetup_test_types/test_types_Required.yaml index 625b135f..cb294860 100644 --- a/testgen/template/dbsetup_test_types/test_types_Required.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Required.yaml @@ -116,6 +116,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10024' + test_type: Required + sql_flavor: salesforce_data360 + measure: |- + COUNT(*) - COUNT({COLUMN_NAME}) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1386' test_id: '1030' @@ -192,4 +200,12 @@ test_types: lookup_query: |- SELECT * FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE "{COLUMN_NAME}" IS NULL LIMIT {LIMIT} error_type: Test Results + - id: '10031' + test_id: 1030 + test_type: Required + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT * FROM "{TABLE_NAME}" WHERE "{COLUMN_NAME}" IS NULL LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Row_Ct.yaml b/testgen/template/dbsetup_test_types/test_types_Row_Ct.yaml index b5c4459d..06c3d62e 100644 --- a/testgen/template/dbsetup_test_types/test_types_Row_Ct.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Row_Ct.yaml @@ -115,6 +115,13 @@ test_types: test_operator: < test_condition: |- {THRESHOLD_VALUE} + - id: '10025' + test_type: Row_Ct + sql_flavor: salesforce_data360 + measure: COUNT(*) + test_operator: '<' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1387' test_id: '1031' @@ -195,4 +202,12 @@ test_types: lookup_query: |- WITH CTE AS (SELECT COUNT(*) AS current_count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}") SELECT current_count, ABS(ROUND(100 * (current_count - {THRESHOLD_VALUE}) / {THRESHOLD_VALUE}, 2)) AS row_count_pct_decrease FROM cte WHERE current_count < {THRESHOLD_VALUE} error_type: Test Results + - id: '10032' + test_id: 1031 + test_type: Row_Ct + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + WITH cte AS (SELECT COUNT(*) AS current_count FROM "{TABLE_NAME}") SELECT current_count, ABS(ROUND(100 * CAST((current_count - {THRESHOLD_VALUE}) AS NUMERIC) / CAST({THRESHOLD_VALUE} AS NUMERIC),2)) AS row_count_pct_decrease FROM cte WHERE current_count < {THRESHOLD_VALUE}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Row_Ct_Pct.yaml b/testgen/template/dbsetup_test_types/test_types_Row_Ct_Pct.yaml index 05efdf4c..47cbf379 100644 --- a/testgen/template/dbsetup_test_types/test_types_Row_Ct_Pct.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Row_Ct_Pct.yaml @@ -116,6 +116,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10026' + test_type: Row_Ct_Pct + sql_flavor: salesforce_data360 + measure: |- + ABS(ROUND(100.0 * CAST((COUNT(*) - {BASELINE_CT}) AS DECIMAL(18,4)) / CAST({BASELINE_CT} AS DECIMAL(18,4)), 2)) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1388' test_id: '1032' @@ -195,4 +203,12 @@ test_types: lookup_query: |- WITH CTE AS (SELECT COUNT(*) AS current_count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}") SELECT current_count, {BASELINE_CT} AS baseline_count, ABS(ROUND(100 * (current_count - {BASELINE_CT}) / {BASELINE_CT}, 2)) AS row_count_pct_difference FROM cte error_type: Test Results + - id: '10033' + test_id: 1032 + test_type: Row_Ct_Pct + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT COUNT(*) AS current_count, {BASELINE_CT} AS baseline_count, ABS(ROUND(100 * CAST((COUNT(*) - {BASELINE_CT}) AS NUMERIC) / CAST({BASELINE_CT} AS NUMERIC), 2)) AS row_count_pct_difference FROM "{TABLE_NAME}"; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Schema_Drift.yaml b/testgen/template/dbsetup_test_types/test_types_Schema_Drift.yaml index e5c908a7..4992ba2c 100644 --- a/testgen/template/dbsetup_test_types/test_types_Schema_Drift.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Schema_Drift.yaml @@ -530,3 +530,58 @@ test_types: AS result_message, column_adds + column_drops + column_mods AS result_measure FROM table_changes; + - id: '10010' + test_type: Schema_Drift + sql_flavor: salesforce_data360 + template: |- + WITH prev_test AS ( + SELECT MAX(test_starttime) AS last_run_time + FROM {APP_SCHEMA_NAME}.test_runs + WHERE test_suite_id = '{TEST_SUITE_ID}'::UUID + -- Ignore current run + AND id <> '{TEST_RUN_ID}'::UUID + ), + table_changes AS ( + SELECT + dsl.table_name, + MAX(prev_test.last_run_time) as window_start, + MAX(CASE WHEN dsl.column_id IS NULL AND dsl.change = 'A' THEN dsl.change_date ELSE NULL END) as last_add_date, + MAX(CASE WHEN dsl.column_id IS NULL AND dsl.change = 'D' THEN dsl.change_date ELSE NULL END) as last_drop_date, + COUNT(*) FILTER (WHERE dsl.column_id IS NOT NULL AND dsl.change = 'A') AS column_adds, + COUNT(*) FILTER (WHERE dsl.column_id IS NOT NULL AND dsl.change = 'D') AS column_drops, + COUNT(*) FILTER (WHERE dsl.column_id IS NOT NULL AND dsl.change = 'M') AS column_mods + FROM {APP_SCHEMA_NAME}.data_structure_log dsl + CROSS JOIN prev_test + WHERE dsl.table_groups_id = '{TABLE_GROUPS_ID}'::UUID + -- if no previous tests, this comparision yelds null and nothing is counted + AND dsl.change_date > prev_test.last_run_time + GROUP BY dsl.table_name + ) + SELECT + '{TEST_TYPE}' AS test_type, + '{TEST_DEFINITION_ID}' AS test_definition_id, + '{TEST_SUITE_ID}' AS test_suite_id, + '{TEST_RUN_ID}' AS test_run_id, + '{RUN_DATE}' AS test_time, + '{SCHEMA_NAME}' AS schema_name, + table_name, + '{INPUT_PARAMETERS}' AS input_parameters, + (CASE + WHEN last_add_date IS NOT NULL AND (last_drop_date IS NULL OR last_add_date > last_drop_date) THEN 'A' + WHEN last_drop_date IS NOT NULL AND (last_add_date IS NULL OR last_drop_date > last_add_date) THEN 'D' + ELSE 'M' + END) + || '|' || column_adds + || '|' || column_drops + || '|' || column_mods + || '|' || window_start::TEXT + AS result_signal, + 0 AS result_code, + CASE WHEN last_add_date IS NOT NULL AND (last_drop_date IS NULL OR last_add_date > last_drop_date) THEN 'Table added. ' ELSE '' END + || CASE WHEN last_drop_date IS NOT NULL AND (last_add_date IS NULL OR last_drop_date > last_add_date) THEN 'Table dropped. ' ELSE '' END + || CASE WHEN column_adds > 0 THEN column_adds || ' columns added. ' ELSE '' END + || CASE WHEN column_drops > 0 THEN column_drops || ' columns dropped. ' ELSE '' END + || CASE WHEN column_mods > 0 THEN column_mods || ' columns modified. ' ELSE '' END + AS result_message, + column_adds + column_drops + column_mods AS result_measure + FROM table_changes; diff --git a/testgen/template/dbsetup_test_types/test_types_Street_Addr_Pattern.yaml b/testgen/template/dbsetup_test_types/test_types_Street_Addr_Pattern.yaml index 7956ef0a..36b83009 100644 --- a/testgen/template/dbsetup_test_types/test_types_Street_Addr_Pattern.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Street_Addr_Pattern.yaml @@ -118,6 +118,14 @@ test_types: test_operator: < test_condition: |- {THRESHOLD_VALUE} + - id: '10027' + test_type: Street_Addr_Pattern + sql_flavor: salesforce_data360 + measure: |- + 100.0*CAST(SUM(CASE WHEN REGEXP_LIKE({COLUMN_NAME}, '^[0-9]{1,5}[a-zA-Z]?\s\w{1,5}\.?\s?\w*\s?\w*\s[a-zA-Z]{1,6}\.?\s?[0-9]{0,5}[A-Z]{0,1}$') THEN 1 ELSE 0 END) AS FLOAT) / CAST(NULLIF(COUNT({COLUMN_NAME}), 0) AS FLOAT) + test_operator: '<' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1389' test_id: '1033' @@ -196,4 +204,12 @@ test_types: lookup_query: |- SELECT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE NOT TO_VARCHAR("{COLUMN_NAME}") LIKE_REGEXPR '^[0-9]{1,5}[a-zA-Z]?[[:space:]][[:alnum:]_]{1,5}\.?[[:space:]]?[[:alnum:]_]*[[:space:]]?[[:alnum:]_]*[[:space:]][a-zA-Z]{1,6}\.?[[:space:]]?[0-9]{0,5}[A-Z]{0,1}$' GROUP BY "{COLUMN_NAME}" ORDER BY COUNT(*) DESC LIMIT {LIMIT} error_type: Test Results + - id: '10034' + test_id: 1033 + test_type: Street_Addr_Pattern + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE NOT REGEXP_LIKE("{COLUMN_NAME}", '^[0-9]{1,5}[a-zA-Z]?\s\w{1,5}\.?\s?\w*\s?\w*\s[a-zA-Z]{1,6}\.?\s?[0-9]{0,5}[A-Z]{0,1}$') GROUP BY "{COLUMN_NAME}" ORDER BY COUNT(*) DESC LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Table_Freshness.yaml b/testgen/template/dbsetup_test_types/test_types_Table_Freshness.yaml index 76823e83..d81e834e 100644 --- a/testgen/template/dbsetup_test_types/test_types_Table_Freshness.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Table_Freshness.yaml @@ -329,3 +329,35 @@ test_types: FROM {QUOTE}{SCHEMA_NAME}{QUOTE}.{QUOTE}{TABLE_NAME}{QUOTE} WHERE {SUBSET_CONDITION} ) test + - id: '10011' + test_type: Table_Freshness + sql_flavor: salesforce_data360 + template: |- + SELECT '{TEST_TYPE}' as test_type, + '{TEST_DEFINITION_ID}' as test_definition_id, + '{TEST_SUITE_ID}' as test_suite_id, + '{TEST_RUN_ID}' as test_run_id, + '{RUN_DATE}' as test_time, + '{SCHEMA_NAME}' as schema_name, + '{TABLE_NAME}' as table_name, + '{COLUMN_NAME_NO_QUOTES}' as column_names, + '{SKIP_ERRORS}' as threshold_value, + {SKIP_ERRORS} as skip_errors, + '{INPUT_PARAMETERS}' as input_parameters, + fingerprint as result_signal, + CASE + WHEN '{LOWER_TOLERANCE}' = 'NULL' OR fingerprint = '{LOWER_TOLERANCE}' THEN 0 + ELSE 1 + END AS result_code, + CASE + WHEN '{LOWER_TOLERANCE}' = 'NULL' OR fingerprint = '{LOWER_TOLERANCE}' THEN 'No table change detected.' + ELSE 'Table change detected.' + END AS result_message, + CASE + WHEN '{LOWER_TOLERANCE}' = 'NULL' OR fingerprint = '{LOWER_TOLERANCE}' THEN 0 + ELSE 1 + END AS result_measure + FROM ( SELECT MD5({CUSTOM_QUERY}) as fingerprint + FROM {QUOTE}{TABLE_NAME}{QUOTE} + WHERE {SUBSET_CONDITION} + ) test; diff --git a/testgen/template/dbsetup_test_types/test_types_Timeframe_Combo_Gain.yaml b/testgen/template/dbsetup_test_types/test_types_Timeframe_Combo_Gain.yaml index 61346177..13bdb85c 100644 --- a/testgen/template/dbsetup_test_types/test_types_Timeframe_Combo_Gain.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Timeframe_Combo_Gain.yaml @@ -199,6 +199,26 @@ test_types: GROUP BY {COLUMN_NAME_NO_QUOTES} LIMIT {LIMIT} error_type: Test Results + - id: '10035' + test_id: 1508 + test_type: Timeframe_Combo_Gain + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT {COLUMN_NAME_NO_QUOTES} + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + AND {WINDOW_DATE_COLUMN} >= DATE_ADD('day', -2 * {WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + AND {WINDOW_DATE_COLUMN} < DATE_ADD('day', -{WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + GROUP BY {COLUMN_NAME_NO_QUOTES} + EXCEPT + SELECT {COLUMN_NAME_NO_QUOTES} + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + AND {WINDOW_DATE_COLUMN} >= DATE_ADD('day', -{WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + GROUP BY {COLUMN_NAME_NO_QUOTES} + LIMIT {LIMIT}; + error_type: Test Results test_templates: - id: '2507' test_type: Timeframe_Combo_Gain @@ -602,3 +622,49 @@ test_types: AND {WINDOW_DATE_COLUMN} >= ADD_DAYS((SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{SCHEMA_NAME}"."{TABLE_NAME}"), -{WINDOW_DAYS}) GROUP BY {COLUMN_NAME_NO_QUOTES} ) test + - id: '10012' + test_type: Timeframe_Combo_Gain + sql_flavor: salesforce_data360 + template: |- + SELECT '{TEST_TYPE}' as test_type, + '{TEST_DEFINITION_ID}' as test_definition_id, + '{TEST_SUITE_ID}' as test_suite_id, + '{TEST_RUN_ID}' as test_run_id, + '{RUN_DATE}' as test_time, + '{SCHEMA_NAME}' as schema_name, + '{TABLE_NAME}' as table_name, + '{COLUMN_NAME_NO_QUOTES}' as column_names, + '{SKIP_ERRORS}' as threshold_value, + {SKIP_ERRORS} as skip_errors, + '{INPUT_PARAMETERS}' as input_parameters, + NULL as result_signal, + CASE WHEN COUNT (*) > {SKIP_ERRORS} THEN 0 ELSE 1 END as result_code, + CASE + WHEN COUNT(*) > 0 THEN + CONCAT( + CONCAT( CAST(COUNT(*) AS VARCHAR), ' error(s) identified, ' ), + CONCAT( + CASE + WHEN COUNT(*) > {SKIP_ERRORS} THEN 'exceeding limit of ' + ELSE 'within limit of ' + END, + '{SKIP_ERRORS}.' + ) + ) + ELSE 'No errors found.' + END AS result_message, + COUNT(*) as result_measure + FROM ( + SELECT {COLUMN_NAME_NO_QUOTES} + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + AND {WINDOW_DATE_COLUMN} >= DATE_ADD('day', -2 * {WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + AND {WINDOW_DATE_COLUMN} < DATE_ADD('day', -{WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + GROUP BY {COLUMN_NAME_NO_QUOTES} + EXCEPT + SELECT {COLUMN_NAME_NO_QUOTES} + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + AND {WINDOW_DATE_COLUMN} >= DATE_ADD('day', -{WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + GROUP BY {COLUMN_NAME_NO_QUOTES} + ) test; diff --git a/testgen/template/dbsetup_test_types/test_types_Timeframe_Combo_Match.yaml b/testgen/template/dbsetup_test_types/test_types_Timeframe_Combo_Match.yaml index e3d2086a..6b10231d 100644 --- a/testgen/template/dbsetup_test_types/test_types_Timeframe_Combo_Match.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Timeframe_Combo_Match.yaml @@ -340,6 +340,40 @@ test_types: LIMIT {LIMIT_2} ) error_type: Test Results + - id: '10036' + test_id: 1509 + test_type: Timeframe_Combo_Match + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + WITH prior_diff AS ( + SELECT 'Prior Timeframe' as missing_from, {COLUMN_NAME_NO_QUOTES} + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + AND {WINDOW_DATE_COLUMN} >= DATE_ADD('day', -{WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + EXCEPT + SELECT 'Prior Timeframe' as missing_from, {COLUMN_NAME_NO_QUOTES} + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + AND {WINDOW_DATE_COLUMN} >= DATE_ADD('day', -2 * {WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + AND {WINDOW_DATE_COLUMN} < DATE_ADD('day', -{WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + ), + latest_diff AS ( + SELECT 'Latest Timeframe' as missing_from, {COLUMN_NAME_NO_QUOTES} + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + AND {WINDOW_DATE_COLUMN} >= DATE_ADD('day', -2 * {WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + AND {WINDOW_DATE_COLUMN} < DATE_ADD('day', -{WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + EXCEPT + SELECT 'Latest Timeframe' as missing_from, {COLUMN_NAME_NO_QUOTES} + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + AND {WINDOW_DATE_COLUMN} >= DATE_ADD('day', -{WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + ) + SELECT * FROM (SELECT * FROM prior_diff LIMIT {LIMIT_2}) p + UNION ALL + SELECT * FROM (SELECT * FROM latest_diff LIMIT {LIMIT_2}) l + error_type: Test Results test_templates: - id: '2508' test_type: Timeframe_Combo_Match @@ -881,3 +915,62 @@ test_types: AND {WINDOW_DATE_COLUMN} >= ADD_DAYS((SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{SCHEMA_NAME}"."{TABLE_NAME}"), -{WINDOW_DAYS}) ) ) test + - id: '10013' + test_type: Timeframe_Combo_Match + sql_flavor: salesforce_data360 + template: |- + SELECT '{TEST_TYPE}' as test_type, + '{TEST_DEFINITION_ID}' as test_definition_id, + '{TEST_SUITE_ID}' as test_suite_id, + '{TEST_RUN_ID}' as test_run_id, + '{RUN_DATE}' as test_time, + '{SCHEMA_NAME}' as schema_name, + '{TABLE_NAME}' as table_name, + '{COLUMN_NAME_NO_QUOTES}' as column_names, + '{SKIP_ERRORS}' as threshold_value, + {SKIP_ERRORS} as skip_errors, + '{INPUT_PARAMETERS}' as input_parameters, + NULL as result_signal, + CASE WHEN COUNT (*) > {SKIP_ERRORS} THEN 0 ELSE 1 END as result_code, + CASE + WHEN COUNT(*) > 0 THEN + CONCAT( + CONCAT( CAST(COUNT(*) AS VARCHAR), ' error(s) identified, ' ), + CONCAT( + CASE + WHEN COUNT(*) > {SKIP_ERRORS} THEN 'exceeding limit of ' + ELSE 'within limit of ' + END, + '{SKIP_ERRORS}.' + ) + ) + ELSE 'No errors found.' + END AS result_message, + COUNT(*) as result_measure + FROM ( + ( + SELECT 'Prior Timeframe' as missing_from, {COLUMN_NAME_NO_QUOTES} + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + AND {WINDOW_DATE_COLUMN} >= DATE_ADD('day', -{WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + EXCEPT + SELECT 'Prior Timeframe' as missing_from, {COLUMN_NAME_NO_QUOTES} + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + AND {WINDOW_DATE_COLUMN} >= DATE_ADD('day', -2 * {WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + AND {WINDOW_DATE_COLUMN} < DATE_ADD('day', -{WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + ) + UNION ALL + ( + SELECT 'Latest Timeframe' as missing_from, {COLUMN_NAME_NO_QUOTES} + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + AND {WINDOW_DATE_COLUMN} >= DATE_ADD('day', -2 * {WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + AND {WINDOW_DATE_COLUMN} < DATE_ADD('day', -{WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + EXCEPT + SELECT 'Latest Timeframe' as missing_from, {COLUMN_NAME_NO_QUOTES} + FROM "{TABLE_NAME}" + WHERE {SUBSET_CONDITION} + AND {WINDOW_DATE_COLUMN} >= DATE_ADD('day', -{WINDOW_DAYS}, (SELECT MAX({WINDOW_DATE_COLUMN}) FROM "{TABLE_NAME}")) + ) + ) test; diff --git a/testgen/template/dbsetup_test_types/test_types_US_State.yaml b/testgen/template/dbsetup_test_types/test_types_US_State.yaml index a14181e8..397611df 100644 --- a/testgen/template/dbsetup_test_types/test_types_US_State.yaml +++ b/testgen/template/dbsetup_test_types/test_types_US_State.yaml @@ -118,6 +118,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10028' + test_type: US_State + sql_flavor: salesforce_data360 + measure: |- + SUM(CASE WHEN NULLIF({COLUMN_NAME}, '') NOT IN ('AL','AK','AS','AZ','AR','CA','CO','CT','DE','DC','FM','FL','GA','GU','HI','ID','IL','IN','IA','KS','KY','LA','ME','MH','MD','MA','MI','MN','MS','MO','MT','NE','NV','NH','NJ','NM','NY','NC','ND','MP','OH','OK','OR','PW','PA','PR','RI','SC','SD','TN','TX','UT','VT','VI','VA','WA','WV','WI','WY','AE','AP','AA') THEN 1 ELSE 0 END) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1390' test_id: '1036' @@ -195,4 +203,12 @@ test_types: lookup_query: |- SELECT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE "{COLUMN_NAME}" IS NOT NULL AND "{COLUMN_NAME}" NOT IN ('AL','AK','AS','AZ','AR','CA','CO','CT','DE','DC','FM','FL','GA','GU','HI','ID','IL','IN','IA','KS','KY','LA','ME','MH','MD','MA','MI','MN','MS','MO','MT','NE','NV','NH','NJ','NM','NY','NC','ND','MP','OH','OK','OR','PW','PA','PR','RI','SC','SD','TN','TX','UT','VT','VI','VA','WA','WV','WI','WY','AE','AP','AA') GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT} error_type: Test Results + - id: '10037' + test_id: 1036 + test_type: US_State + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" WHERE NULLIF("{COLUMN_NAME}", '') NOT IN ('AL','AK','AS','AZ','AR','CA','CO','CT','DE','DC','FM','FL','GA','GU','HI','ID','IL','IN','IA','KS','KY','LA','ME','MH','MD','MA','MI','MN','MS','MO','MT','NE','NV','NH','NJ','NM','NY','NC','ND','MP','OH','OK','OR','PW','PA','PR','RI','SC','SD','TN','TX','UT','VT','VI','VA','WA','WV','WI','WY','AE','AP','AA') GROUP BY "{COLUMN_NAME}" LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Unique.yaml b/testgen/template/dbsetup_test_types/test_types_Unique.yaml index abf22dae..e1b5b661 100644 --- a/testgen/template/dbsetup_test_types/test_types_Unique.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Unique.yaml @@ -118,6 +118,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10029' + test_type: Unique + sql_flavor: salesforce_data360 + measure: |- + COUNT(*) - COUNT(DISTINCT {COLUMN_NAME}) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1391' test_id: '1034' @@ -196,4 +204,12 @@ test_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" HAVING COUNT(*) > 1 ORDER BY COUNT(*) DESC LIMIT {LIMIT} error_type: Test Results + - id: '10038' + test_id: 1034 + test_type: Unique + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" HAVING COUNT(*) > 1 ORDER BY COUNT(*) DESC LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Unique_Pct.yaml b/testgen/template/dbsetup_test_types/test_types_Unique_Pct.yaml index 6e8767ae..6cb5908a 100644 --- a/testgen/template/dbsetup_test_types/test_types_Unique_Pct.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Unique_Pct.yaml @@ -118,6 +118,14 @@ test_types: test_operator: '>=' test_condition: |- {THRESHOLD_VALUE} + - id: '10030' + test_type: Unique_Pct + sql_flavor: salesforce_data360 + measure: |- + ABS( 2.0 * ASIN( SQRT(CAST({BASELINE_UNIQUE_CT} AS FLOAT) / CAST({BASELINE_VALUE_CT} AS FLOAT) ) ) - 2 * ASIN( SQRT( CAST(COUNT( DISTINCT {COLUMN_NAME} ) AS FLOAT) / CAST(NULLIF(COUNT( {COLUMN_NAME} ), 0) AS FLOAT) )) ) + test_operator: '>=' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1392' test_id: '1035' @@ -195,4 +203,12 @@ test_types: lookup_query: |- SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY COUNT(*) DESC LIMIT {LIMIT} error_type: Test Results + - id: '10039' + test_id: 1035 + test_type: Unique_Pct + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT DISTINCT "{COLUMN_NAME}", COUNT(*) AS count FROM "{TABLE_NAME}" GROUP BY "{COLUMN_NAME}" ORDER BY COUNT(*) DESC LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Valid_Characters.yaml b/testgen/template/dbsetup_test_types/test_types_Valid_Characters.yaml index 6110a2f9..cd73a08c 100644 --- a/testgen/template/dbsetup_test_types/test_types_Valid_Characters.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Valid_Characters.yaml @@ -118,6 +118,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10031' + test_type: Valid_Characters + sql_flavor: salesforce_data360 + measure: |- + SUM(CASE WHEN REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE({COLUMN_NAME}, CHR(160), ''), CHR(8203), ''), CHR(65279), ''), CHR(8239), ''), CHR(8201), ''), CHR(12288), ''), CHR(8204), '') <> {COLUMN_NAME} OR {COLUMN_NAME} LIKE ' %' OR {COLUMN_NAME} LIKE '''%''' OR {COLUMN_NAME} LIKE '"%"' THEN 1 ELSE 0 END) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1397' test_id: '1043' @@ -199,4 +207,12 @@ test_types: lookup_query: |- SELECT "{COLUMN_NAME}", COUNT(*) AS record_ct FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE("{COLUMN_NAME}", NCHAR(160), ''), NCHAR(8203), ''), NCHAR(65279), ''), NCHAR(8239), ''), NCHAR(8201), ''), NCHAR(12288), ''), NCHAR(8204), '') <> "{COLUMN_NAME}" OR "{COLUMN_NAME}" LIKE ' %' OR "{COLUMN_NAME}" LIKE '''%''' OR "{COLUMN_NAME}" LIKE '"%"' GROUP BY "{COLUMN_NAME}" ORDER BY record_ct DESC LIMIT {LIMIT} error_type: Test Results + - id: '10040' + test_id: 1043 + test_type: Valid_Characters + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT "{COLUMN_NAME}", COUNT(*) AS record_ct FROM "{TABLE_NAME}" WHERE REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE("{COLUMN_NAME}", CHR(160), ''), CHR(8203), ''), CHR(65279), ''), CHR(8239), ''), CHR(8201), ''), CHR(12288), ''), CHR(8204), '') <> "{COLUMN_NAME}" OR "{COLUMN_NAME}" LIKE ' %' OR "{COLUMN_NAME}" LIKE '''%''' OR "{COLUMN_NAME}" LIKE '"%"' GROUP BY "{COLUMN_NAME}" ORDER BY record_ct DESC LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Valid_Month.yaml b/testgen/template/dbsetup_test_types/test_types_Valid_Month.yaml index a5a8fbcd..fab14ff1 100644 --- a/testgen/template/dbsetup_test_types/test_types_Valid_Month.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Valid_Month.yaml @@ -117,5 +117,13 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10032' + test_type: Valid_Month + sql_flavor: salesforce_data360 + measure: |- + SUM(CASE WHEN NULLIF({COLUMN_NAME}, '') NOT IN ({BASELINE_VALUE}) THEN 1 ELSE 0 END) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: [] test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Valid_US_Zip.yaml b/testgen/template/dbsetup_test_types/test_types_Valid_US_Zip.yaml index e5225b67..e380caef 100644 --- a/testgen/template/dbsetup_test_types/test_types_Valid_US_Zip.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Valid_US_Zip.yaml @@ -116,6 +116,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10033' + test_type: Valid_US_Zip + sql_flavor: salesforce_data360 + measure: |- + SUM(CASE WHEN REGEXP_REPLACE({COLUMN_NAME}, '[0-9]', '9', 'g') NOT IN ('99999', '999999999', '99999-9999') THEN 1 ELSE 0 END) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1398' test_id: '1044' @@ -194,4 +202,12 @@ test_types: lookup_query: |- SELECT "{COLUMN_NAME}", COUNT(*) AS record_ct FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE REPLACE_REGEXPR('[0-9]' IN "{COLUMN_NAME}" WITH '9') NOT IN ('99999', '999999999', '99999-9999') GROUP BY "{COLUMN_NAME}" ORDER BY record_ct DESC LIMIT {LIMIT} error_type: Test Results + - id: '10041' + test_id: 1044 + test_type: Valid_US_Zip + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT "{COLUMN_NAME}", COUNT(*) AS record_ct FROM "{TABLE_NAME}" WHERE REGEXP_REPLACE("{COLUMN_NAME}", '[0-9]', '9', 'g') NOT IN ('99999', '999999999', '99999-9999') GROUP BY "{COLUMN_NAME}" ORDER BY record_ct DESC LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Valid_US_Zip3.yaml b/testgen/template/dbsetup_test_types/test_types_Valid_US_Zip3.yaml index 5d174ae7..45218af9 100644 --- a/testgen/template/dbsetup_test_types/test_types_Valid_US_Zip3.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Valid_US_Zip3.yaml @@ -117,6 +117,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10034' + test_type: Valid_US_Zip3 + sql_flavor: salesforce_data360 + measure: |- + SUM(CASE WHEN REGEXP_REPLACE({COLUMN_NAME}, '[0-9]', '9', 'g') <> '999' THEN 1 ELSE 0 END) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1399' test_id: '1045' @@ -195,4 +203,12 @@ test_types: lookup_query: |- SELECT "{COLUMN_NAME}", COUNT(*) AS record_ct FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" WHERE REPLACE_REGEXPR('[0-9]' IN "{COLUMN_NAME}" WITH '9') <> '999' GROUP BY "{COLUMN_NAME}" ORDER BY record_ct DESC LIMIT {LIMIT} error_type: Test Results + - id: '10042' + test_id: 1045 + test_type: Valid_US_Zip3 + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT "{COLUMN_NAME}", COUNT(*) AS record_ct FROM "{TABLE_NAME}" WHERE REGEXP_REPLACE("{COLUMN_NAME}", '[0-9]', '9', 'g') <> '999' GROUP BY "{COLUMN_NAME}" ORDER BY record_ct DESC LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Variability_Decrease.yaml b/testgen/template/dbsetup_test_types/test_types_Variability_Decrease.yaml index dda3e907..bb671fd8 100644 --- a/testgen/template/dbsetup_test_types/test_types_Variability_Decrease.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Variability_Decrease.yaml @@ -122,6 +122,14 @@ test_types: test_operator: < test_condition: |- {THRESHOLD_VALUE} + - id: '10035' + test_type: Variability_Decrease + sql_flavor: salesforce_data360 + measure: |- + 100.0*STDDEV(CAST({COLUMN_NAME} AS FLOAT))/CAST({BASELINE_SD} AS FLOAT) + test_operator: '<' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1395' test_id: '1041' @@ -196,4 +204,12 @@ test_types: lookup_query: |- SELECT STDDEV(CAST("{COLUMN_NAME}" AS DECIMAL)) as current_standard_deviation FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" error_type: Test Results + - id: '10043' + test_id: 1041 + test_type: Variability_Decrease + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT STDDEV(CAST("{COLUMN_NAME}" AS FLOAT)) as current_standard_deviation FROM "{TABLE_NAME}"; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Variability_Increase.yaml b/testgen/template/dbsetup_test_types/test_types_Variability_Increase.yaml index 73b0b48d..54e11245 100644 --- a/testgen/template/dbsetup_test_types/test_types_Variability_Increase.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Variability_Increase.yaml @@ -126,6 +126,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10036' + test_type: Variability_Increase + sql_flavor: salesforce_data360 + measure: |- + 100.0*STDDEV(CAST({COLUMN_NAME} AS FLOAT))/CAST({BASELINE_SD} AS FLOAT) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1394' test_id: '1040' @@ -200,4 +208,12 @@ test_types: lookup_query: |- SELECT STDDEV(CAST("{COLUMN_NAME}" AS DECIMAL)) as current_standard_deviation FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" error_type: Test Results + - id: '10044' + test_id: 1040 + test_type: Variability_Increase + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT STDDEV(CAST("{COLUMN_NAME}" AS FLOAT)) as current_standard_deviation FROM "{TABLE_NAME}"; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Volume_Trend.yaml b/testgen/template/dbsetup_test_types/test_types_Volume_Trend.yaml index 521688f6..67c9ff29 100644 --- a/testgen/template/dbsetup_test_types/test_types_Volume_Trend.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Volume_Trend.yaml @@ -107,6 +107,14 @@ test_types: test_operator: NOT BETWEEN test_condition: |- {LOWER_TOLERANCE} AND {UPPER_TOLERANCE} + - id: '10037' + test_type: Volume_Trend + sql_flavor: salesforce_data360 + measure: |- + {CUSTOM_QUERY} + test_operator: NOT BETWEEN + test_condition: |- + {LOWER_TOLERANCE} AND {UPPER_TOLERANCE} target_data_lookups: - id: '1477' test_id: '1513' @@ -207,4 +215,15 @@ test_types: {UPPER_TOLERANCE} AS upper_bound FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" error_type: Test Results + - id: '10045' + test_id: 1513 + test_type: Volume_Trend + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + SELECT {CUSTOM_QUERY} AS current_count, + {LOWER_TOLERANCE} AS lower_bound, + {UPPER_TOLERANCE} AS upper_bound + FROM "{TABLE_NAME}"; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Weekly_Rec_Ct.yaml b/testgen/template/dbsetup_test_types/test_types_Weekly_Rec_Ct.yaml index 73a115dc..bf0e91df 100644 --- a/testgen/template/dbsetup_test_types/test_types_Weekly_Rec_Ct.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Weekly_Rec_Ct.yaml @@ -118,6 +118,14 @@ test_types: test_operator: '>' test_condition: |- {THRESHOLD_VALUE} + - id: '10038' + test_type: Weekly_Rec_Ct + sql_flavor: salesforce_data360 + measure: |- + MAX(DATEDIFF('week', CAST('1800-01-01' AS DATE), CAST({COLUMN_NAME} AS DATE))) - MIN(DATEDIFF('week', CAST('1800-01-01' AS DATE), CAST({COLUMN_NAME} AS DATE)))+1 - COUNT(DISTINCT DATEDIFF('week', CAST('1800-01-01' AS DATE), CAST({COLUMN_NAME} AS DATE))) + test_operator: '>' + test_condition: |- + {THRESHOLD_VALUE} target_data_lookups: - id: '1393' test_id: '1037' @@ -259,4 +267,12 @@ test_types: lookup_query: |- WITH Pass0 AS (SELECT 1 C FROM DUMMY UNION ALL SELECT 1 FROM DUMMY), Pass1 AS (SELECT 1 C FROM Pass0 A, Pass0 B), Pass2 AS (SELECT 1 C FROM Pass1 A, Pass1 B), Pass3 AS (SELECT 1 C FROM Pass2 A, Pass2 B), nums AS (SELECT ROW_NUMBER() OVER (ORDER BY C) - 1 AS rn FROM Pass3), bounds AS (SELECT ADD_DAYS(CAST(MIN("{COLUMN_NAME}") AS DATE), -WEEKDAY(CAST(MIN("{COLUMN_NAME}") AS DATE))) AS min_week, ADD_DAYS(CAST(MAX("{COLUMN_NAME}") AS DATE), -WEEKDAY(CAST(MAX("{COLUMN_NAME}") AS DATE))) AS max_week FROM "{TARGET_SCHEMA}"."{TABLE_NAME}"), daterange AS (SELECT ADD_DAYS(b.min_week, n.rn * 7) AS all_dates FROM bounds b, nums n WHERE ADD_DAYS(b.min_week, n.rn * 7) <= b.max_week), existing_periods AS (SELECT DISTINCT ADD_DAYS(CAST("{COLUMN_NAME}" AS DATE), -WEEKDAY(CAST("{COLUMN_NAME}" AS DATE))) AS period, COUNT(1) AS period_count FROM "{TARGET_SCHEMA}"."{TABLE_NAME}" GROUP BY ADD_DAYS(CAST("{COLUMN_NAME}" AS DATE), -WEEKDAY(CAST("{COLUMN_NAME}" AS DATE)))) SELECT p.missing_period, p.prior_available_week, e.period_count AS prior_available_week_count, p.next_available_week, f.period_count AS next_available_week_count FROM (SELECT d.all_dates AS missing_period, MAX(b.period) AS prior_available_week, MIN(c.period) AS next_available_week FROM daterange d LEFT JOIN existing_periods a ON d.all_dates = a.period LEFT JOIN existing_periods b ON b.period < d.all_dates LEFT JOIN existing_periods c ON c.period > d.all_dates WHERE a.period IS NULL AND d.all_dates BETWEEN b.period AND c.period GROUP BY d.all_dates) p LEFT JOIN existing_periods e ON (p.prior_available_week = e.period) LEFT JOIN existing_periods f ON (p.next_available_week = f.period) ORDER BY p.missing_period LIMIT {LIMIT} error_type: Test Results + - id: '10046' + test_id: 1037 + test_type: Weekly_Rec_Ct + sql_flavor: salesforce_data360 + lookup_type: null + lookup_query: |- + WITH RECURSIVE daterange(all_dates) AS (SELECT CAST(DATE_TRUNC('week', MIN("{COLUMN_NAME}")) AS DATE) AS all_dates FROM "{TABLE_NAME}" UNION ALL SELECT CAST((d.all_dates + INTERVAL '1 week' ) AS DATE) AS all_dates FROM daterange d WHERE d.all_dates < (SELECT CAST(DATE_TRUNC('week' , MAX("{COLUMN_NAME}")) AS DATE) FROM "{TABLE_NAME}") ), existing_periods AS (SELECT DISTINCT CAST(DATE_TRUNC('week', "{COLUMN_NAME}") AS DATE) AS period, COUNT(1) as period_count FROM "{TABLE_NAME}" GROUP BY CAST(DATE_TRUNC('week', "{COLUMN_NAME}") AS DATE)) SELECT p.missing_period, p.prior_available_week, e.period_count AS prior_available_week_count, p.next_available_week, f.period_count AS next_available_week_count FROM (SELECT d.all_dates AS missing_period, MAX(b.period) AS prior_available_week, MIN(c.period) AS next_available_week FROM daterange d LEFT JOIN existing_periods a ON d.all_dates = a.period LEFT JOIN existing_periods b ON b.period < d.all_dates LEFT JOIN existing_periods c ON c.period > d.all_dates WHERE a.period IS NULL AND d.all_dates BETWEEN b.period AND c.period GROUP BY d.all_dates) p LEFT JOIN existing_periods e ON (p.prior_available_week = e.period) LEFT JOIN existing_periods f ON (p.next_available_week = f.period) ORDER BY p.missing_period LIMIT {LIMIT}; + error_type: Test Results test_templates: [] diff --git a/testgen/template/dbupgrade/0190_incremental_upgrade.sql b/testgen/template/dbupgrade/0190_incremental_upgrade.sql new file mode 100644 index 00000000..a44bf9af --- /dev/null +++ b/testgen/template/dbupgrade/0190_incremental_upgrade.sql @@ -0,0 +1,4 @@ +SET SEARCH_PATH TO {SCHEMA_NAME}; + +-- Widen project_user to accommodate Salesforce Data 360 Consumer Keys (86+ chars) +ALTER TABLE connections ALTER COLUMN project_user TYPE VARCHAR(256); diff --git a/testgen/template/flavors/salesforce_data360/gen_query_tests/gen_Dupe_Rows.sql b/testgen/template/flavors/salesforce_data360/gen_query_tests/gen_Dupe_Rows.sql new file mode 100644 index 00000000..8ad665ae --- /dev/null +++ b/testgen/template/flavors/salesforce_data360/gen_query_tests/gen_Dupe_Rows.sql @@ -0,0 +1,55 @@ +WITH latest_run AS ( + -- Latest complete profiling run before as-of-date + SELECT MAX(run_date) AS last_run_date + FROM profile_results + WHERE table_groups_id = :TABLE_GROUPS_ID ::UUID + AND run_date::DATE <= :AS_OF_DATE ::DATE +), +selected_tables AS ( + SELECT profile_run_id, schema_name, table_name, + STRING_AGG(:QUOTE || column_name || :QUOTE, ', ' ORDER BY position) AS groupby_names + FROM profile_results p + INNER JOIN latest_run lr ON p.run_date = lr.last_run_date + WHERE table_groups_id = :TABLE_GROUPS_ID ::UUID + -- Skip X types - Hyper does not support GROUP BY on JSON columns (and VARBINARY by extension) + AND general_type <> 'X' + GROUP BY profile_run_id, schema_name, table_name +) +INSERT INTO test_definitions ( + table_groups_id, test_suite_id, test_type, + schema_name, table_name, + test_active, last_auto_gen_date, profiling_as_of_date, profile_run_id, + groupby_names, skip_errors +) +SELECT + :TABLE_GROUPS_ID ::UUID AS table_groups_id, + :TEST_SUITE_ID ::UUID AS test_suite_id, + 'Dupe_Rows' AS test_type, + s.schema_name, + s.table_name, + 'Y' AS test_active, + :RUN_DATE ::TIMESTAMP AS last_auto_gen_date, + :AS_OF_DATE ::TIMESTAMP AS profiling_as_of_date, + s.profile_run_id, + s.groupby_names, + 0 AS skip_errors +FROM selected_tables s + -- Only insert if test type is active +WHERE EXISTS (SELECT 1 FROM test_types WHERE test_type = 'Dupe_Rows' AND active = 'Y') + -- Only insert if test type is included in generation set + AND EXISTS (SELECT 1 FROM generation_sets WHERE test_type = 'Dupe_Rows' AND generation_set = :GENERATION_SET) + +-- Match "uix_td_autogen_table" unique index exactly +ON CONFLICT (test_suite_id, test_type, schema_name, table_name) +WHERE last_auto_gen_date IS NOT NULL + AND table_name IS NOT NULL + AND column_name IS NULL + +-- Update tests if they already exist +DO UPDATE SET + test_active = EXCLUDED.test_active, + last_auto_gen_date = EXCLUDED.last_auto_gen_date, + groupby_names = EXCLUDED.groupby_names, + skip_errors = EXCLUDED.skip_errors +-- Ignore locked tests +WHERE test_definitions.lock_refresh = 'N'; diff --git a/testgen/template/flavors/salesforce_data360/gen_query_tests/gen_Freshness_Trend.sql b/testgen/template/flavors/salesforce_data360/gen_query_tests/gen_Freshness_Trend.sql new file mode 100644 index 00000000..8a8f1ec7 --- /dev/null +++ b/testgen/template/flavors/salesforce_data360/gen_query_tests/gen_Freshness_Trend.sql @@ -0,0 +1,210 @@ +WITH latest_run AS ( + -- Latest complete profiling run before as-of-date + SELECT MAX(run_date) AS last_run_date + FROM profile_results + WHERE table_groups_id = :TABLE_GROUPS_ID ::UUID + AND run_date::DATE <= :AS_OF_DATE ::DATE +), +latest_results AS ( + -- Column results for latest run + SELECT p.profile_run_id, p.schema_name, p.table_name, p.column_name, + p.functional_data_type, p.general_type, + p.distinct_value_ct, p.record_ct, p.null_value_ct, + p.max_value, p.min_value, p.avg_value, p.stdev_value + FROM profile_results p + INNER JOIN latest_run lr ON p.run_date = lr.last_run_date + INNER JOIN data_table_chars dtc ON ( + dtc.table_groups_id = p.table_groups_id + AND dtc.schema_name = p.schema_name + AND dtc.table_name = p.table_name + -- Ignore dropped tables + AND dtc.drop_date IS NULL + ) + INNER JOIN data_column_chars dcc ON ( + dcc.table_groups_id = p.table_groups_id + AND dcc.schema_name = p.schema_name + AND dcc.table_name = p.table_name + AND dcc.column_name = p.column_name + -- Ignore dropped columns + AND dcc.drop_date IS NULL + ) + WHERE p.table_groups_id = :TABLE_GROUPS_ID ::UUID +), +-- IDs - TOP 2 +id_cols AS ( + SELECT profile_run_id, schema_name, table_name, column_name, + functional_data_type, general_type, distinct_value_ct, + ROW_NUMBER() OVER ( + PARTITION BY schema_name, table_name + ORDER BY + CASE + WHEN functional_data_type ILIKE 'ID-Unique%' THEN 1 + WHEN functional_data_type = 'ID-Secondary' THEN 2 + ELSE 3 + END, distinct_value_ct DESC, column_name + ) AS rank + FROM latest_results + WHERE general_type IN ('A', 'D', 'N') + AND functional_data_type ILIKE 'ID%' +), +-- Process Date - TOP 1 +process_date_cols AS ( + SELECT profile_run_id, schema_name, table_name, column_name, + functional_data_type, general_type, distinct_value_ct, + ROW_NUMBER() OVER ( + PARTITION BY schema_name, table_name + ORDER BY + CASE + WHEN column_name ILIKE '%mod%' THEN 1 + WHEN column_name ILIKE '%up%' THEN 1 + WHEN column_name ILIKE '%cr%' THEN 2 + WHEN column_name ILIKE '%in%' THEN 2 + END, distinct_value_ct DESC, column_name + ) AS rank + FROM latest_results + WHERE general_type IN ('A', 'D', 'N') + AND functional_data_type ILIKE 'process%' +), +-- Transaction Date - TOP 1 +tran_date_cols AS ( + SELECT profile_run_id, schema_name, table_name, column_name, + functional_data_type, general_type, distinct_value_ct, + ROW_NUMBER() OVER ( + PARTITION BY schema_name, table_name + ORDER BY distinct_value_ct DESC, column_name + ) AS rank + FROM latest_results + WHERE general_type IN ('A', 'D', 'N') + AND functional_data_type ILIKE 'transactional date%' + OR functional_data_type ILIKE 'period%' + OR functional_data_type = 'timestamp' +), +-- Numeric Measures +numeric_cols AS ( + SELECT profile_run_id, schema_name, table_name, column_name, + functional_data_type, general_type, +/* + -- Subscores + distinct_value_ct * 1.0 / NULLIF(record_ct, 0) AS cardinality_score, + (max_value - min_value) / NULLIF(ABS(NULLIF(avg_value, 0)), 1) AS range_score, + LEAST(1, LOG(GREATEST(distinct_value_ct, 2))) / LOG(GREATEST(record_ct, 2)) AS nontriviality_score, + stdev_value / NULLIF(ABS(NULLIF(avg_value, 0)), 1) AS variability_score, + 1.0 - (null_value_ct * 1.0 / NULLIF(NULLIF(record_ct, 0), 1)) AS null_penalty, +*/ + -- Weighted score + ( + 0.25 * (distinct_value_ct * 1.0 / NULLIF(record_ct, 0)) + + 0.15 * ((max_value - min_value) / NULLIF(ABS(NULLIF(avg_value, 0)), 1)) + + 0.10 * (LEAST(1, LOG(GREATEST(distinct_value_ct, 2))) / LOG(GREATEST(record_ct, 2))) + + 0.40 * (stdev_value / NULLIF(ABS(NULLIF(avg_value, 0)), 1)) + + 0.10 * (1.0 - (null_value_ct * 1.0 / NULLIF(NULLIF(record_ct, 0), 1))) + ) AS change_detection_score + FROM latest_results + WHERE general_type = 'N' + AND ( + functional_data_type ILIKE 'Measure%' + OR functional_data_type IN ('Sequence', 'Constant') + ) +), +numeric_cols_ranked AS ( + SELECT *, + ROW_NUMBER() OVER ( + PARTITION BY schema_name, table_name + ORDER BY change_detection_score DESC, column_name + ) AS rank + FROM numeric_cols + WHERE change_detection_score IS NOT NULL +), +combined AS ( + SELECT profile_run_id, schema_name, table_name, column_name, + 'ID' AS element_type, general_type, 10 + rank AS fingerprint_order + FROM id_cols + WHERE rank <= 2 + UNION ALL + SELECT profile_run_id, schema_name, table_name, column_name, + 'DATE_P' AS element_type, general_type, 20 + rank AS fingerprint_order + FROM process_date_cols + WHERE rank = 1 + UNION ALL + SELECT profile_run_id, schema_name, table_name, column_name, + 'DATE_T' AS element_type, general_type, 30 + rank AS fingerprint_order + FROM tran_date_cols + WHERE rank = 1 + UNION ALL + SELECT profile_run_id, schema_name, table_name, column_name, + 'MEAS' AS element_type, general_type, 40 + rank AS fingerprint_order + FROM numeric_cols_ranked + WHERE rank = 1 +), +selected_tables AS ( + SELECT profile_run_id, schema_name, table_name, + STRING_AGG(column_name, ',' ORDER BY element_type, fingerprint_order, column_name) AS column_names, + 'CAST(COUNT(*) AS VARCHAR) || ''|'' || ' || + STRING_AGG( + REPLACE( + CASE + WHEN general_type = 'D' THEN 'CAST(MIN(@@@) AS VARCHAR) || ''|'' || CAST(MAX(@@@) AS VARCHAR) || ''|'' || CAST(COUNT(DISTINCT @@@) AS VARCHAR)' + WHEN general_type = 'A' THEN 'CAST(MIN(@@@) AS VARCHAR) || ''|'' || CAST(MAX(@@@) AS VARCHAR) || ''|'' || CAST(COUNT(DISTINCT @@@) AS VARCHAR) || ''|'' || CAST(SUM(LENGTH(@@@)) AS VARCHAR)' + WHEN general_type = 'N' THEN 'CAST(COUNT(@@@) AS VARCHAR) || ''|'' || + CAST(COUNT(DISTINCT MOD(CAST(CAST(COALESCE(@@@,0) AS DECIMAL(38,6)) * 1000000 AS DECIMAL(38,0)), 1000003)) AS VARCHAR) || ''|'' || + COALESCE(CAST(CAST(MIN(@@@) AS DECIMAL(38,6)) AS VARCHAR), '''') || ''|'' || + COALESCE(CAST(CAST(MAX(@@@) AS DECIMAL(38,6)) AS VARCHAR), '''') || ''|'' || + COALESCE(CAST(MOD(COALESCE(SUM(MOD(CAST(ABS(COALESCE(@@@,0)) AS DECIMAL(38,6)) * 1000000, 1000000007)), 0), 1000000007) AS VARCHAR), '''') || ''|'' || + COALESCE(CAST(MOD(COALESCE(SUM(MOD(CAST(ABS(COALESCE(@@@,0)) AS DECIMAL(38,6)) * 1000000, 1000000009)), 0), 1000000009) AS VARCHAR), '''')' + END, + '@@@', '"' || column_name || '"' + ), + ' || ''|'' || ' + ORDER BY element_type, fingerprint_order, column_name + ) AS fingerprint + FROM combined + GROUP BY profile_run_id, schema_name, table_name +) +-- Insert tests for selected tables +INSERT INTO test_definitions ( + table_groups_id, test_suite_id, test_type, + schema_name, table_name, groupby_names, + test_active, last_auto_gen_date, profiling_as_of_date, profile_run_id, + history_calculation, history_lookback, custom_query +) +SELECT + :TABLE_GROUPS_ID ::UUID AS table_groups_id, + :TEST_SUITE_ID ::UUID AS test_suite_id, + 'Freshness_Trend' AS test_type, + s.schema_name, + s.table_name, + s.column_names AS groupby_names, + 'Y' AS test_active, + :RUN_DATE ::TIMESTAMP AS last_auto_gen_date, + :AS_OF_DATE ::TIMESTAMP AS profiling_as_of_date, + s.profile_run_id, + 'PREDICT' AS history_calculation, + NULL AS history_lookback, + s.fingerprint AS custom_query +FROM selected_tables s + -- Only insert if test type is active +WHERE EXISTS (SELECT 1 FROM test_types WHERE test_type = 'Freshness_Trend' AND active = 'Y') + -- Only insert if test type is included in generation set + AND EXISTS (SELECT 1 FROM generation_sets WHERE test_type = 'Freshness_Trend' AND generation_set = :GENERATION_SET) + {TABLE_FILTER} + +-- Match "uix_td_autogen_table" unique index exactly +ON CONFLICT (test_suite_id, test_type, schema_name, table_name) +WHERE last_auto_gen_date IS NOT NULL + AND table_name IS NOT NULL + AND column_name IS NULL + +-- Update tests if they already exist +DO UPDATE SET + groupby_names = EXCLUDED.groupby_names, + test_active = EXCLUDED.test_active, + last_auto_gen_date = EXCLUDED.last_auto_gen_date, + profiling_as_of_date = EXCLUDED.profiling_as_of_date, + profile_run_id = EXCLUDED.profile_run_id, + history_calculation = EXCLUDED.history_calculation, + history_lookback = EXCLUDED.history_lookback, + custom_query = EXCLUDED.custom_query +-- Ignore locked tests +WHERE test_definitions.lock_refresh = 'N' + -- Don't update existing tests in "insert" mode + AND NOT COALESCE(:INSERT_ONLY, FALSE); diff --git a/testgen/template/flavors/salesforce_data360/gen_query_tests/gen_Table_Freshness.sql b/testgen/template/flavors/salesforce_data360/gen_query_tests/gen_Table_Freshness.sql new file mode 100644 index 00000000..0fca64f2 --- /dev/null +++ b/testgen/template/flavors/salesforce_data360/gen_query_tests/gen_Table_Freshness.sql @@ -0,0 +1,189 @@ +WITH latest_run AS ( + -- Latest complete profiling run before as-of-date + SELECT MAX(run_date) AS last_run_date + FROM profile_results + WHERE table_groups_id = :TABLE_GROUPS_ID ::UUID + AND run_date::DATE <= :AS_OF_DATE ::DATE +), +latest_results AS ( + -- Column results for latest run + SELECT profile_run_id, schema_name, table_name, column_name, + functional_data_type, general_type, + distinct_value_ct, record_ct, null_value_ct, + max_value, min_value, avg_value, stdev_value + FROM profile_results p + INNER JOIN latest_run lr ON p.run_date = lr.last_run_date + WHERE table_groups_id = :TABLE_GROUPS_ID ::UUID +), +-- IDs - TOP 2 +id_cols AS ( + SELECT profile_run_id, schema_name, table_name, column_name, + functional_data_type, general_type, distinct_value_ct, + ROW_NUMBER() OVER ( + PARTITION BY schema_name, table_name + ORDER BY + CASE + WHEN functional_data_type ILIKE 'ID-Unique%' THEN 1 + WHEN functional_data_type = 'ID-Secondary' THEN 2 + ELSE 3 + END, distinct_value_ct DESC, column_name + ) AS rank + FROM latest_results + WHERE general_type IN ('A', 'D', 'N') + AND functional_data_type ILIKE 'ID%' +), +-- Process Date - TOP 1 +process_date_cols AS ( + SELECT profile_run_id, schema_name, table_name, column_name, + functional_data_type, general_type, distinct_value_ct, + ROW_NUMBER() OVER ( + PARTITION BY schema_name, table_name + ORDER BY + CASE + WHEN column_name ILIKE '%mod%' THEN 1 + WHEN column_name ILIKE '%up%' THEN 1 + WHEN column_name ILIKE '%cr%' THEN 2 + WHEN column_name ILIKE '%in%' THEN 2 + END, distinct_value_ct DESC, column_name + ) AS rank + FROM latest_results + WHERE general_type IN ('A', 'D', 'N') + AND functional_data_type ILIKE 'process%' +), +-- Transaction Date - TOP 1 +tran_date_cols AS ( + SELECT profile_run_id, schema_name, table_name, column_name, + functional_data_type, general_type, distinct_value_ct, + ROW_NUMBER() OVER ( + PARTITION BY schema_name, table_name + ORDER BY distinct_value_ct DESC, column_name + ) AS rank + FROM latest_results + WHERE general_type IN ('A', 'D', 'N') + AND functional_data_type ILIKE 'transactional date%' + OR functional_data_type ILIKE 'period%' + OR functional_data_type = 'timestamp' +), +-- Numeric Measures +numeric_cols AS ( + SELECT profile_run_id, schema_name, table_name, column_name, + functional_data_type, general_type, +/* + -- Subscores + distinct_value_ct * 1.0 / NULLIF(record_ct, 0) AS cardinality_score, + (max_value - min_value) / NULLIF(ABS(NULLIF(avg_value, 0)), 1) AS range_score, + LEAST(1, LOG(GREATEST(distinct_value_ct, 2))) / LOG(GREATEST(record_ct, 2)) AS nontriviality_score, + stdev_value / NULLIF(ABS(NULLIF(avg_value, 0)), 1) AS variability_score, + 1.0 - (null_value_ct * 1.0 / NULLIF(NULLIF(record_ct, 0), 1)) AS null_penalty, +*/ + -- Weighted score + ( + 0.25 * (distinct_value_ct * 1.0 / NULLIF(record_ct, 0)) + + 0.15 * ((max_value - min_value) / NULLIF(ABS(NULLIF(avg_value, 0)), 1)) + + 0.10 * (LEAST(1, LOG(GREATEST(distinct_value_ct, 2))) / LOG(GREATEST(record_ct, 2))) + + 0.40 * (stdev_value / NULLIF(ABS(NULLIF(avg_value, 0)), 1)) + + 0.10 * (1.0 - (null_value_ct * 1.0 / NULLIF(NULLIF(record_ct, 0), 1))) + ) AS change_detection_score + FROM latest_results + WHERE general_type = 'N' + AND ( + functional_data_type ILIKE 'Measure%' + OR functional_data_type IN ('Sequence', 'Constant') + ) +), +numeric_cols_ranked AS ( + SELECT *, + ROW_NUMBER() OVER ( + PARTITION BY schema_name, table_name + ORDER BY change_detection_score DESC, column_name + ) AS rank + FROM numeric_cols + WHERE change_detection_score IS NOT NULL +), +combined AS ( + SELECT profile_run_id, schema_name, table_name, column_name, + 'ID' AS element_type, general_type, 10 + rank AS fingerprint_order + FROM id_cols + WHERE rank <= 2 + UNION ALL + SELECT profile_run_id, schema_name, table_name, column_name, + 'DATE_P' AS element_type, general_type, 20 + rank AS fingerprint_order + FROM process_date_cols + WHERE rank = 1 + UNION ALL + SELECT profile_run_id, schema_name, table_name, column_name, + 'DATE_T' AS element_type, general_type, 30 + rank AS fingerprint_order + FROM tran_date_cols + WHERE rank = 1 + UNION ALL + SELECT profile_run_id, schema_name, table_name, column_name, + 'MEAS' AS element_type, general_type, 40 + rank AS fingerprint_order + FROM numeric_cols_ranked + WHERE rank = 1 +), +selected_tables AS ( + SELECT profile_run_id, schema_name, table_name, + 'CAST(COUNT(*) AS VARCHAR) || ''|'' || ' || + STRING_AGG( + REPLACE( + CASE + WHEN general_type = 'D' THEN 'CAST(MIN(@@@) AS VARCHAR) || ''|'' || CAST(MAX(@@@) AS VARCHAR) || ''|'' || CAST(COUNT(DISTINCT @@@) AS VARCHAR)' + WHEN general_type = 'A' THEN 'CAST(MIN(@@@) AS VARCHAR) || ''|'' || CAST(MAX(@@@) AS VARCHAR) || ''|'' || CAST(COUNT(DISTINCT @@@) AS VARCHAR) || ''|'' || CAST(SUM(LENGTH(@@@)) AS VARCHAR)' + WHEN general_type = 'N' THEN 'CAST(COUNT(@@@) AS VARCHAR) || ''|'' || + CAST(COUNT(DISTINCT MOD(CAST(CAST(COALESCE(@@@,0) AS DECIMAL(38,6)) * 1000000 AS DECIMAL(38,0)), 1000003)) AS VARCHAR) || ''|'' || + COALESCE(CAST(CAST(MIN(@@@) AS DECIMAL(38,6)) AS VARCHAR), '''') || ''|'' || + COALESCE(CAST(CAST(MAX(@@@) AS DECIMAL(38,6)) AS VARCHAR), '''') || ''|'' || + COALESCE(CAST(MOD(COALESCE(SUM(MOD(CAST(ABS(COALESCE(@@@,0)) AS DECIMAL(38,6)) * 1000000, 1000000007)), 0), 1000000007) AS VARCHAR), '''') || ''|'' || + COALESCE(CAST(MOD(COALESCE(SUM(MOD(CAST(ABS(COALESCE(@@@,0)) AS DECIMAL(38,6)) * 1000000, 1000000009)), 0), 1000000009) AS VARCHAR), '''')' + END, + '@@@', '"' || column_name || '"' + ), + ' || ''|'' || ' + ORDER BY element_type, fingerprint_order, column_name + ) AS fingerprint + FROM combined + GROUP BY profile_run_id, schema_name, table_name +) +-- Insert tests for selected tables +INSERT INTO test_definitions ( + table_groups_id, test_suite_id, test_type, + schema_name, table_name, + test_active, last_auto_gen_date, profiling_as_of_date, profile_run_id, + history_calculation, history_lookback, custom_query +) +SELECT + :TABLE_GROUPS_ID ::UUID AS table_groups_id, + :TEST_SUITE_ID ::UUID AS test_suite_id, + 'Table_Freshness' AS test_type, + s.schema_name, + s.table_name, + 'Y' AS test_active, + :RUN_DATE ::TIMESTAMP AS last_auto_gen_date, + :AS_OF_DATE ::TIMESTAMP AS profiling_as_of_date, + s.profile_run_id, + 'Value' AS history_calculation, + 1 AS history_lookback, + s.fingerprint AS custom_query +FROM selected_tables s + -- Only insert if test type is active +WHERE EXISTS (SELECT 1 FROM test_types WHERE test_type = 'Table_Freshness' AND active = 'Y') + -- Only insert if test type is included in generation set + AND EXISTS (SELECT 1 FROM generation_sets WHERE test_type = 'Table_Freshness' AND generation_set = :GENERATION_SET) + +-- Match "uix_td_autogen_table" unique index exactly +ON CONFLICT (test_suite_id, test_type, schema_name, table_name) +WHERE last_auto_gen_date IS NOT NULL + AND table_name IS NOT NULL + AND column_name IS NULL + +-- Update tests if they already exist +DO UPDATE SET + test_active = EXCLUDED.test_active, + last_auto_gen_date = EXCLUDED.last_auto_gen_date, + profiling_as_of_date = EXCLUDED.profiling_as_of_date, + profile_run_id = EXCLUDED.profile_run_id, + history_calculation = EXCLUDED.history_calculation, + history_lookback = EXCLUDED.history_lookback, + custom_query = EXCLUDED.custom_query +-- Ignore locked tests +WHERE test_definitions.lock_refresh = 'N'; diff --git a/testgen/template/flavors/salesforce_data360/profiling/project_profiling_query.sql b/testgen/template/flavors/salesforce_data360/profiling/project_profiling_query.sql new file mode 100644 index 00000000..a16aeb1c --- /dev/null +++ b/testgen/template/flavors/salesforce_data360/profiling/project_profiling_query.sql @@ -0,0 +1,247 @@ +WITH target_table AS ( +-- TG-IF do_sample + SELECT * FROM "{DATA_TABLE}" ORDER BY RANDOM() LIMIT {SAMPLE_SIZE} +-- TG-ELSE + SELECT * FROM "{DATA_TABLE}" +-- TG-ENDIF +) +SELECT + {CONNECTION_ID} AS connection_id, + '{PROJECT_CODE}' AS project_code, + '{TABLE_GROUPS_ID}' AS table_groups_id, + '{DATA_SCHEMA}' AS schema_name, + '{RUN_DATE}' AS run_date, + '{DATA_TABLE}' AS table_name, + {COL_POS} AS position, + '{COL_NAME_SANITIZED}' AS column_name, + '{COL_TYPE}' AS column_type, + '{DB_DATA_TYPE}' AS db_data_type, + '{COL_GEN_TYPE}' AS general_type, + COUNT(*) AS record_ct, + COUNT("{COL_NAME}") AS value_ct, + COUNT(DISTINCT "{COL_NAME}") AS distinct_value_ct, + SUM(CASE WHEN "{COL_NAME}" IS NULL THEN 1 ELSE 0 END) AS null_value_ct, +-- TG-IF is_type_ADN + MIN(LENGTH(CAST("{COL_NAME}" AS VARCHAR))) AS min_length, + MAX(LENGTH(CAST("{COL_NAME}" AS VARCHAR))) AS max_length, + AVG(CAST(NULLIF(LENGTH(CAST("{COL_NAME}" AS VARCHAR)), 0) AS DOUBLE)) AS avg_length, +-- TG-ELSE + NULL AS min_length, + NULL AS max_length, + NULL AS avg_length, +-- TG-ENDIF +-- TG-IF is_type_A + SUM(CASE + WHEN REGEXP_LIKE(TRIM("{COL_NAME}"), '^0(\.0*)?$') THEN 1 ELSE 0 + END) AS zero_value_ct, +-- TG-ENDIF +-- TG-IF is_type_N + SUM(CASE WHEN CAST("{COL_NAME}" AS DOUBLE) = 0 THEN 1 ELSE 0 END) AS zero_value_ct, +-- TG-ENDIF +-- TG-IF is_not_A_not_N + NULL AS zero_value_ct, +-- TG-ENDIF +-- TG-IF is_type_A + COUNT(DISTINCT UPPER(REPLACE(REPLACE(REPLACE(REPLACE(REPLACE("{COL_NAME}", ' ', ''), '''', ''), ',', ''), '.', ''), '-', ''))) AS distinct_std_value_ct, + SUM(CASE + WHEN "{COL_NAME}" = '' THEN 1 + ELSE 0 + END) AS zero_length_ct, + SUM( CASE + WHEN "{COL_NAME}" BETWEEN ' !' AND '!' THEN 1 + ELSE 0 + END ) AS lead_space_ct, + SUM( CASE WHEN "{COL_NAME}" LIKE '"%"' OR "{COL_NAME}" LIKE '''%''' THEN 1 ELSE 0 END ) AS quoted_value_ct, + SUM( CASE WHEN REGEXP_LIKE("{COL_NAME}", '[0-9]') THEN 1 ELSE 0 END ) AS includes_digit_ct, + SUM( CASE + WHEN REGEXP_LIKE(LOWER("{COL_NAME}"), '^(\.{1,}|-{1,}|\?{1,}|\s{1,}|0{2,}|9{2,}|x{2,}|z{2,})$') THEN 1 + WHEN LOWER("{COL_NAME}") IN ('blank','error','missing','tbd', + 'n/a','#na','none','null','unknown') THEN 1 + WHEN LOWER("{COL_NAME}") IN ('(blank)','(error)','(missing)','(tbd)', + '(n/a)','(#na)','(none)','(null)','(unknown)') THEN 1 + WHEN LOWER("{COL_NAME}") IN ('[blank]','[error]','[missing]','[tbd]', + '[n/a]','[#na]','[none]','[null]','[unknown]') THEN 1 + ELSE 0 + END ) AS filled_value_ct, + SUBSTR(MIN(NULLIF("{COL_NAME}", '')), 1, 100) AS min_text, + SUBSTR(MAX(NULLIF("{COL_NAME}", '')), 1, 100) AS max_text, + SUM(CASE + WHEN REGEXP_REPLACE("{COL_NAME}", '[A-Za-z]', '', 'g') = "{COL_NAME}" THEN 0 + WHEN REGEXP_REPLACE("{COL_NAME}", '[a-z]', '', 'g') = "{COL_NAME}" THEN 1 + ELSE 0 + END) AS upper_case_ct, + SUM(CASE + WHEN REGEXP_REPLACE("{COL_NAME}", '[A-Za-z]', '', 'g') = "{COL_NAME}" THEN 0 + WHEN REGEXP_REPLACE("{COL_NAME}", '[A-Z]', '', 'g') = "{COL_NAME}" THEN 1 + ELSE 0 + END) AS lower_case_ct, + SUM(CASE + WHEN REGEXP_REPLACE("{COL_NAME}", '[A-Za-z]', '', 'g') = "{COL_NAME}" THEN 1 + ELSE 0 + END) AS non_alpha_ct, + SUM(CASE WHEN REGEXP_REPLACE("{COL_NAME}", + '[' || CHR(160) || CHR(8201) || CHR(8203) || CHR(8204) || CHR(8205) || CHR(8206) || CHR(8207) || CHR(8239) || CHR(12288) || CHR(65279) || ']', + 'X', 'g') <> "{COL_NAME}" THEN 1 ELSE 0 END) AS non_printing_ct, + SUM(<%IS_NUM;SUBSTR("{COL_NAME}", 1, 31)%>) AS numeric_ct, + SUM(<%IS_DATE;SUBSTR("{COL_NAME}", 1, 26)%>) AS date_ct, + CASE + WHEN SUM( CASE WHEN REGEXP_LIKE("{COL_NAME}", '^[0-9]{1,5}[a-zA-Z]?\s\w{1,5}\.?\s?\w*\s?\w*\s[a-zA-Z]{1,6}\.?\s?[0-9]{0,5}[A-Z]{0,1}$') + THEN 1 END) > CAST(0.8 * COUNT("{COL_NAME}") AS BIGINT) THEN 'STREET_ADDR' + WHEN SUM(CASE WHEN "{COL_NAME}" IN ('AL','AK','AS','AZ','AR','CA','CO','CT','DE','DC','FM','FL','GA','GU','HI','ID','IL','IN','IA','KS','KY','LA','ME','MH','MD','MA','MI','MN','MS','MO','MT','NE','NV','NH','NJ','NM','NY','NC','ND','MP','OH','OK','OR','PW','PA','PR','RI','SC','SD','TN','TX','UT','VT','VI','VA','WA','WV','WI','WY','AE','AP','AA') + THEN 1 END) > CAST(0.9 * COUNT("{COL_NAME}") AS BIGINT) THEN 'STATE_USA' + WHEN SUM( CASE WHEN REGEXP_LIKE("{COL_NAME}", '^(\+1|1)?[ .\-]?(\([2-9][0-9]{2}\)|[2-9][0-9]{2})[ .\-]?[2-9][0-9]{2}[ .\-]?[0-9]{4}$') + THEN 1 END) > CAST(0.8 * COUNT("{COL_NAME}") AS BIGINT) THEN 'PHONE_USA' + WHEN SUM( CASE WHEN REGEXP_LIKE("{COL_NAME}", '^[A-Za-z0-9._%+\-]+@[A-Za-z0-9.\-]+\.[A-Za-z]{2,}$') + AND "{COL_NAME}" NOT LIKE '%://%' + THEN 1 END) > CAST(0.9 * COUNT("{COL_NAME}") AS BIGINT) THEN 'EMAIL' + WHEN SUM( CASE WHEN REGEXP_LIKE(REGEXP_REPLACE("{COL_NAME}", '[0-9]', '9', 'g'), '^(99999|999999999|99999-9999)$') + THEN 1 END) > CAST(0.9 * COUNT("{COL_NAME}") AS BIGINT) THEN 'ZIP_USA' + WHEN SUM( CASE WHEN REGEXP_LIKE("{COL_NAME}", '^[\w\s\-]+\.(txt|csv|tsv|dat|doc|pdf|xlsx)$') + THEN 1 END) > CAST(0.9 * COUNT("{COL_NAME}") AS BIGINT) THEN 'FILE_NAME' + WHEN SUM( CASE WHEN REGEXP_LIKE("{COL_NAME}", '^([0-9]{4}[- ]?){3}[0-9]{4}$') + THEN 1 END) > CAST(0.8 * COUNT("{COL_NAME}") AS BIGINT) THEN 'CREDIT_CARD' + WHEN SUM( CASE WHEN REGEXP_LIKE("{COL_NAME}", '^([^,|\t]{1,20}[,|\t]){2,}[^,|\t]{0,20}([,|\t]{0,1}[^,|\t]{0,20})*$') + AND NOT REGEXP_LIKE("{COL_NAME}", '\s(and|but|or|yet)\s') + THEN 1 END) > CAST(0.8 * COUNT("{COL_NAME}") AS BIGINT) THEN 'DELIMITED_DATA' + WHEN SUM ( CASE WHEN REGEXP_LIKE("{COL_NAME}", '^[0-8][0-9]{2}-[0-9]{2}-[0-9]{4}$') + AND SUBSTR("{COL_NAME}", 1, 3) NOT BETWEEN '734' AND '749' + AND SUBSTR("{COL_NAME}", 1, 3) <> '666' THEN 1 END) > CAST(0.9 * COUNT("{COL_NAME}") AS BIGINT) THEN 'SSN' + END AS std_pattern_match, +-- TG-ELSE + NULL AS distinct_std_value_ct, + NULL AS zero_length_ct, + NULL AS lead_space_ct, + NULL AS quoted_value_ct, + NULL AS includes_digit_ct, + NULL AS filled_value_ct, + NULL AS min_text, + NULL AS max_text, + NULL AS upper_case_ct, + NULL AS lower_case_ct, + NULL AS non_alpha_ct, + NULL AS non_printing_ct, + NULL AS numeric_ct, + NULL AS date_ct, + NULL AS std_pattern_match, +-- TG-ENDIF +-- TG-IF is_type_A + (SELECT SUBSTR(ARRAY_JOIN(ARRAY_AGG(pattern), ' | '), 1, 1000) AS concat_pats + FROM ( + SELECT CAST(COUNT(*) AS VARCHAR) || ' | ' || pattern AS pattern, + COUNT(*) AS ct + FROM ( SELECT REGEXP_REPLACE(REGEXP_REPLACE( REGEXP_REPLACE( + "{COL_NAME}", '[a-z]', 'a', 'g'), + '[A-Z]', 'A', 'g'), + '[0-9]', 'N', 'g') AS pattern + FROM target_table + WHERE "{COL_NAME}" > ' ' AND (SELECT MAX(LENGTH("{COL_NAME}")) + FROM target_table) BETWEEN 3 and {MAX_PATTERN_LENGTH}) p + GROUP BY pattern + HAVING pattern > ' ' + ORDER BY COUNT(*) DESC + LIMIT 5 + ) ps) AS top_patterns, +-- TG-ELSE + NULL AS top_patterns, +-- TG-ENDIF +-- TG-IF is_type_N + MIN("{COL_NAME}") AS min_value, + MIN(CASE WHEN CAST("{COL_NAME}" AS DOUBLE) > 0 THEN "{COL_NAME}" ELSE NULL END) AS min_value_over_0, + MAX("{COL_NAME}") AS max_value, + AVG(CAST("{COL_NAME}" AS DOUBLE)) AS avg_value, + STDDEV(CAST("{COL_NAME}" AS DOUBLE)) AS stdev_value, + APPROX_PERCENTILE(CAST("{COL_NAME}" AS DOUBLE), 0.25) AS percentile_25, + APPROX_PERCENTILE(CAST("{COL_NAME}" AS DOUBLE), 0.50) AS percentile_50, + APPROX_PERCENTILE(CAST("{COL_NAME}" AS DOUBLE), 0.75) AS percentile_75, +-- TG-ELSE + NULL AS min_value, + NULL AS min_value_over_0, + NULL AS max_value, + NULL AS avg_value, + NULL AS stdev_value, + NULL AS percentile_25, + NULL AS percentile_50, + NULL AS percentile_75, +-- TG-ENDIF +-- TG-IF is_N_decimal + SUM(ROUND(ABS(MOD(CAST("{COL_NAME}" AS DOUBLE), 1)), 5)) AS fractional_sum, +-- TG-ELSE + NULL AS fractional_sum, +-- TG-ENDIF +-- TG-IF is_type_D + CASE + WHEN MIN("{COL_NAME}") IS NULL THEN NULL + ELSE GREATEST(MIN("{COL_NAME}"), CAST('0001-01-01' AS TIMESTAMP)) + END AS min_date, + MAX("{COL_NAME}") AS max_date, + SUM(CASE + WHEN DATEDIFF('month', CAST("{COL_NAME}" AS TIMESTAMP), CAST('{RUN_DATE}' AS TIMESTAMP)) > 12 THEN 1 + ELSE 0 + END) AS before_1yr_date_ct, + SUM(CASE + WHEN DATEDIFF('month', CAST("{COL_NAME}" AS TIMESTAMP), CAST('{RUN_DATE}' AS TIMESTAMP)) > 60 THEN 1 + ELSE 0 + END) AS before_5yr_date_ct, + SUM(CASE + WHEN DATEDIFF('month', CAST("{COL_NAME}" AS TIMESTAMP), CAST('{RUN_DATE}' AS TIMESTAMP)) > 240 THEN 1 + ELSE 0 + END) AS before_20yr_date_ct, + SUM(CASE + WHEN DATEDIFF('month', CAST("{COL_NAME}" AS TIMESTAMP), CAST('{RUN_DATE}' AS TIMESTAMP)) > 1200 THEN 1 + ELSE 0 + END) AS before_100yr_date_ct, + SUM(CASE + WHEN DATEDIFF('day', CAST("{COL_NAME}" AS TIMESTAMP), CAST('{RUN_DATE}' AS TIMESTAMP)) BETWEEN 0 AND 365 THEN 1 + ELSE 0 + END) AS within_1yr_date_ct, + SUM(CASE + WHEN DATEDIFF('day', CAST("{COL_NAME}" AS TIMESTAMP), CAST('{RUN_DATE}' AS TIMESTAMP)) BETWEEN 0 AND 30 THEN 1 + ELSE 0 + END) AS within_1mo_date_ct, + SUM(CASE + WHEN "{COL_NAME}" > CAST('{RUN_DATE}' AS TIMESTAMP) THEN 1 ELSE 0 + END) AS future_date_ct, + SUM(CASE + WHEN DATEDIFF('month', CAST('{RUN_DATE}' AS TIMESTAMP), CAST("{COL_NAME}" AS TIMESTAMP)) > 240 THEN 1 + ELSE 0 + END) AS distant_future_date_ct, + COUNT(DISTINCT DATEDIFF('day', CAST("{COL_NAME}" AS TIMESTAMP), CAST('{RUN_DATE}' AS TIMESTAMP))) AS date_days_present, + COUNT(DISTINCT DATEDIFF('week', CAST("{COL_NAME}" AS TIMESTAMP), CAST('{RUN_DATE}' AS TIMESTAMP))) AS date_weeks_present, + COUNT(DISTINCT DATEDIFF('month', CAST("{COL_NAME}" AS TIMESTAMP), CAST('{RUN_DATE}' AS TIMESTAMP))) AS date_months_present, +-- TG-ELSE + NULL AS min_date, + NULL AS max_date, + NULL AS before_1yr_date_ct, + NULL AS before_5yr_date_ct, + NULL AS before_20yr_date_ct, + NULL AS before_100yr_date_ct, + NULL AS within_1yr_date_ct, + NULL AS within_1mo_date_ct, + NULL AS future_date_ct, + NULL AS distant_future_date_ct, + NULL AS date_days_present, + NULL AS date_weeks_present, + NULL AS date_months_present, +-- TG-ENDIF +-- TG-IF is_type_B + SUM(CAST("{COL_NAME}" AS INTEGER)) AS boolean_true_ct, +-- TG-ELSE + NULL AS boolean_true_ct, +-- TG-ENDIF +-- TG-IF is_type_A + (SELECT COUNT(DISTINCT REGEXP_REPLACE( REGEXP_REPLACE( REGEXP_REPLACE( + "{COL_NAME}", '[a-z]', 'a', 'g'), + '[A-Z]', 'A', 'g'), + '[0-9]', 'N', 'g') + ) AS pattern_ct + FROM target_table + WHERE "{COL_NAME}" > ' ' ) AS distinct_pattern_ct, + SUM(CASE WHEN LENGTH(TRIM("{COL_NAME}")) - LENGTH(REGEXP_REPLACE(TRIM("{COL_NAME}"), ' ', '', 'g')) > 0 THEN 1 ELSE 0 END) AS embedded_space_ct, + AVG(CAST(LENGTH(TRIM("{COL_NAME}")) - LENGTH(REGEXP_REPLACE(TRIM("{COL_NAME}"), ' ', '', 'g')) AS DOUBLE)) AS avg_embedded_spaces, +-- TG-ELSE + NULL AS distinct_pattern_ct, + NULL AS embedded_space_ct, + NULL AS avg_embedded_spaces, +-- TG-ENDIF + '{PROFILE_RUN_ID}' AS profile_run_id + FROM target_table diff --git a/testgen/template/flavors/salesforce_data360/profiling/project_secondary_profiling_query.sql b/testgen/template/flavors/salesforce_data360/profiling/project_secondary_profiling_query.sql new file mode 100644 index 00000000..3e575d78 --- /dev/null +++ b/testgen/template/flavors/salesforce_data360/profiling/project_secondary_profiling_query.sql @@ -0,0 +1,37 @@ +-- Get Freqs for selected columns +WITH target_table AS ( + SELECT * FROM "{DATA_TABLE}" +-- TG-IF do_sample_bool + ORDER BY RANDOM() LIMIT {SAMPLE_SIZE} +-- TG-ENDIF +), +ranked_vals AS ( + SELECT "{COL_NAME}", + COUNT(*) AS ct, + ROW_NUMBER() OVER (ORDER BY COUNT(*) DESC, "{COL_NAME}") AS rn + FROM target_table + WHERE "{COL_NAME}" > ' ' + GROUP BY "{COL_NAME}" +), +consol_vals AS ( + SELECT COALESCE(CASE WHEN rn <= 10 THEN '| ' || "{COL_NAME}" || ' | ' || CAST(ct AS VARCHAR) + ELSE NULL + END, '| Other Values (' || CAST(COUNT(DISTINCT "{COL_NAME}") as VARCHAR) || ') | ' || CAST(SUM(ct) as VARCHAR) ) AS val, + MIN(rn) as min_rn + FROM ranked_vals + GROUP BY CASE WHEN rn <= 10 THEN '| ' || "{COL_NAME}" || ' | ' || CAST(ct AS VARCHAR) + ELSE NULL + END +) +SELECT '{PROJECT_CODE}' as project_code, + '{DATA_SCHEMA}' as schema_name, + '{RUN_DATE}' as run_date, + '{DATA_TABLE}' as table_name, + '{COL_NAME}' as column_name, + REPLACE(ARRAY_JOIN(ARRAY_AGG(val), '^#^'), '^#^', CHR(10)) AS top_freq_values, + ( SELECT MD5(ARRAY_JOIN(ARRAY_AGG(v), '|')) as dvh + FROM (SELECT DISTINCT NULLIF("{COL_NAME}", '') AS v + FROM target_table + WHERE NULLIF("{COL_NAME}", '') IS NOT NULL + ORDER BY v) sorted_vals ) as distinct_value_hash + FROM (SELECT * FROM consol_vals ORDER BY min_rn LIMIT 11) ordered_vals; diff --git a/testgen/template/flavors/salesforce_data360/profiling/templated_functions.yaml b/testgen/template/flavors/salesforce_data360/profiling/templated_functions.yaml new file mode 100644 index 00000000..7ae06f79 --- /dev/null +++ b/testgen/template/flavors/salesforce_data360/profiling/templated_functions.yaml @@ -0,0 +1,98 @@ +IS_NUM: CASE + WHEN REGEXP_LIKE({$1}, '^\s*[+-]?\$?\s*[0-9]+(,[0-9]{3})*(\.[0-9]*)?[%]?\s*$') THEN 1 + ELSE 0 + END + +IS_DATE: CASE + /* YYYY-MM-DD HH:MM:SS SSSSSS or YYYY-MM-DD HH:MM:SS */ + WHEN REGEXP_LIKE({$1}, '^(\d{4})-(0[1-9]|1[0-2])-(0[1-9]|[12][0-9]|3[01])\s(2[0-3]|[01][0-9]):([0-5][0-9]):([0-5][0-9])(\s[0-9]{6})?$') + THEN CASE + WHEN CAST(SUBSTR({$1}, 1, 4) AS INTEGER) BETWEEN 1800 AND 2200 + AND ( + ( SUBSTRING ({$1}, 6, 2) IN ('01', '03', '05', '07', '08', + '10', '12') + AND CAST(SUBSTRING ({$1}, 9, 2) AS INTEGER) BETWEEN 1 AND 31 ) + OR ( SUBSTRING ({$1}, 6, 2) IN ('04', '06', '09') + AND CAST(SUBSTRING ({$1}, 9, 2) AS INTEGER) BETWEEN 1 AND 30 ) + OR ( SUBSTRING ({$1}, 6, 2) = '02' + AND CAST(SUBSTRING ({$1}, 9, 2) AS INTEGER) BETWEEN 1 AND 29) + ) + THEN 1 + ELSE 0 + END + /* YYYYMMDDHHMMSSSSSS or YYYYMMDD */ + WHEN REGEXP_LIKE({$1}, '^(\d{4})(0[1-9]|1[0-2])(0[1-9]|[12][0-9]|3[01])(2[0-3]|[01][0-9])([0-5][0-9])([0-5][0-9])([0-9]{6})$') + OR REGEXP_LIKE({$1}, '^(\d{4})(0[1-9]|1[0-2])(0[1-9]|[12][0-9]|3[01])(2[0-3]|[01][0-9])$') + THEN CASE + WHEN CAST(SUBSTR({$1}, 1, 4) AS INTEGER) BETWEEN 1800 AND 2200 + AND ( + ( SUBSTRING({$1}, 5, 2) IN ('01', '03', '05', '07', '08', + '10', '12') + AND CAST(SUBSTRING({$1}, 7, 2) AS INTEGER) BETWEEN 1 AND 31 ) + OR ( SUBSTRING({$1}, 5, 2) IN ('04', '06', '09') + AND CAST(SUBSTRING({$1}, 7, 2) AS INTEGER) BETWEEN 1 AND 30 ) + OR ( SUBSTRING({$1}, 5, 2) = '02' + AND CAST(SUBSTRING({$1}, 7, 2) AS INTEGER) BETWEEN 1 AND 29) + ) + THEN 1 + ELSE 0 + END + /* Exclude anything else long */ + WHEN LENGTH({$1}) > 11 THEN 0 + /* YYYY-MMM/MM-DD */ + WHEN REGEXP_LIKE(REGEXP_REPLACE(UPPER({$1}), '(JAN|FEB|MAR|APR|MAY|JUN|JUL|AUG|SEP|OCT|NOV|DEC)', '12', 'g'), + '[12][09][0-9][0-9]-[0-1]?[0-9]-[0-3]?[0-9]') + THEN CASE + WHEN CAST(SPLIT_PART({$1}, '-', 1) AS INTEGER) BETWEEN 1800 AND 2200 + AND ( + ( UPPER(SPLIT_PART({$1}, '-', 2)) IN ('01', '03', '05', '07', '08', + '1', '3', '5', '7', '8', '10', '12', + 'JAN', 'MAR', 'MAY', 'JUL', 'AUG', + 'OCT', 'DEC') + AND CAST(SPLIT_PART({$1}, '-', 3) AS INTEGER) BETWEEN 1 AND 31 ) + OR ( UPPER(SPLIT_PART({$1}, '-', 2)) IN ('04', '06', '09', '4', '6', '9', '11', + 'APR', 'JUN', 'SEP', 'NOV') + AND CAST(SPLIT_PART({$1}, '-', 3) AS INTEGER) BETWEEN 1 AND 30 ) + OR ( UPPER(SPLIT_PART({$1}, '-', 2)) IN ('02', '2', 'FEB') + AND CAST(SPLIT_PART({$1}, '-', 3) AS INTEGER) BETWEEN 1 AND 29) + ) + THEN 1 + ELSE 0 + END + /* MM/-DD/-YY/YYYY */ + WHEN REGEXP_LIKE(REPLACE({$1}, '-', '/'), '^[0-1]?[0-9]/[0-3]?[0-9]/[12][09][0-9][0-9]$') + OR REGEXP_LIKE(REPLACE({$1}, '-', '/'), '^[0-1]?[0-9]/[0-3]?[0-9]/[0-9][0-9]$') + THEN + CASE + WHEN CAST(SPLIT_PART(REPLACE({$1}, '-', '/'), '/', 1) AS INTEGER) BETWEEN 1 AND 12 + AND ( + ( CAST(SPLIT_PART(REPLACE({$1}, '-', '/'), '/', 1) AS INTEGER) IN (1, 3, 5, 7, 8, 10, 12) + AND CAST(SPLIT_PART(REPLACE({$1}, '-', '/'), '/', 2) AS INTEGER) BETWEEN 1 AND 31 ) + OR ( CAST(SPLIT_PART(REPLACE({$1}, '-', '/'), '/', 1) AS INTEGER) IN (4, 6, 9, 11) + AND CAST(SPLIT_PART(REPLACE({$1}, '-', '/'), '/', 2) AS INTEGER) BETWEEN 1 AND 30 ) + OR ( CAST(SPLIT_PART(REPLACE({$1}, '-', '/'), '/', 1) AS INTEGER) = 2 + AND CAST(SPLIT_PART(REPLACE({$1}, '-', '/'), '/', 2) AS INTEGER) BETWEEN 1 AND 29) + ) + AND + CAST('20' || SUBSTRING(SPLIT_PART(REPLACE({$1}, '-', '/'), '/', 3), LENGTH(SPLIT_PART(REPLACE({$1}, '-', '/'), '/', 3)) - 1) AS INTEGER) BETWEEN 1800 AND 2200 + THEN 1 + ELSE 0 + END + /* DD-MMM-YYYY */ + WHEN REGEXP_LIKE(UPPER({$1}), '[0-3]?[0-9]-(JAN|FEB|MAR|APR|MAY|JUN|JUL|AUG|SEP|OCT|NOV|DEC)-[12][09][0-9][0-9]') + THEN + CASE + WHEN CAST(SPLIT_PART({$1}, '-', 3) AS INTEGER) BETWEEN 1800 AND 2200 + AND ( + ( UPPER(SPLIT_PART({$1}, '-', 2)) IN ('JAN', 'MAR', 'MAY', 'JUL', 'AUG', 'OCT', 'DEC') + AND CAST(SPLIT_PART({$1}, '-', 1) AS INTEGER) BETWEEN 1 AND 31 ) + OR ( UPPER(SPLIT_PART({$1}, '-', 2)) IN ('APR', 'JUN', 'SEP', 'NOV') + AND CAST(SPLIT_PART({$1}, '-', 1) AS INTEGER) BETWEEN 1 AND 30 ) + OR ( UPPER(SPLIT_PART({$1}, '-', 2)) = 'FEB' + AND CAST(SPLIT_PART({$1}, '-', 1) AS INTEGER) BETWEEN 1 AND 29) + ) + THEN 1 + ELSE 0 + END + ELSE 0 + END diff --git a/testgen/ui/assets/flavors/salesforce_data360.svg b/testgen/ui/assets/flavors/salesforce_data360.svg new file mode 100644 index 00000000..beacb0d9 --- /dev/null +++ b/testgen/ui/assets/flavors/salesforce_data360.svg @@ -0,0 +1,83 @@ + + + + Salesforce.com logo + A cloud computing company based in San Francisco, California, United States + + + + image/svg+xml + + Salesforce.com logo + + + + + + + + + + + + + + + diff --git a/testgen/ui/components/frontend/js/pages/table_group_list.js b/testgen/ui/components/frontend/js/pages/table_group_list.js index 0e0b819c..3f52a24a 100644 --- a/testgen/ui/components/frontend/js/pages/table_group_list.js +++ b/testgen/ui/components/frontend/js/pages/table_group_list.js @@ -84,7 +84,7 @@ const TableGroupList = (props) => { if (key !== wizardKey) { wizardContainer.innerHTML = ''; wizardKey = key; - van.add(wizardContainer, TableGroupWizard({ emit, + van.add(wizardContainer, TableGroupWizard({ emit, project_code: van.derive(() => getValue(props.wizard)?.project_code), connections: van.derive(() => getValue(props.wizard)?.connections), table_group: van.derive(() => getValue(props.wizard)?.table_group), @@ -115,7 +115,7 @@ const TableGroupList = (props) => { if (key !== editDialogKey) { editDialogContainer.innerHTML = ''; editDialogKey = key; - van.add(editDialogContainer, TableGroupEditDialog({ emit, + van.add(editDialogContainer, TableGroupEditDialog({ emit, dialog: van.derive(() => getValue(props.edit_dialog)?.dialog), connections: van.derive(() => getValue(props.edit_dialog)?.connections), table_group: van.derive(() => getValue(props.edit_dialog)?.table_group), @@ -225,7 +225,7 @@ const TableGroupList = (props) => { { class: 'flex-row fx-gap-3' }, div( { class: 'flex-column fx-flex fx-gap-3' }, - Link({ emit, + Link({ emit, label: 'View test suites', href: 'test-suites', params: { 'project_code': projectSummary.project_code, 'table_group_id': tableGroup.id }, @@ -238,7 +238,7 @@ const TableGroupList = (props) => { { class: 'flex-column fx-flex fx-gap-4' }, div( { class: 'flex-column fx-flex' }, - Caption({content: 'DB Schema', style: 'margin-bottom: 4px;'}), + Caption({content: tableGroup.connection.flavor.flavor === 'salesforce_data360' ? 'Data Space' : 'Schema', style: 'margin-bottom: 4px;'}), span(tableGroup.table_group_schema || '--'), ), div( diff --git a/testgen/ui/components/frontend/js/pages/test_definitions.js b/testgen/ui/components/frontend/js/pages/test_definitions.js index 58c707f5..47bf21d2 100644 --- a/testgen/ui/components/frontend/js/pages/test_definitions.js +++ b/testgen/ui/components/frontend/js/pages/test_definitions.js @@ -854,6 +854,7 @@ const AddDialogComponent = ({ open, info, validateResult: validateResultProp, on const tableGroupsId = van.derive(() => getValue(info)?.table_groups_id ?? ''); const testSuite = van.derive(() => getValue(info)?.test_suite ?? {}); const tableColumns = van.derive(() => getValue(info)?.table_columns ?? []); + const qualifiesTableRefsWithSchema = van.derive(() => getValue(info)?.qualifies_table_refs_with_schema ?? true); const validateResult = van.derive(() => getValue(validateResultProp) ?? null); const scopeFilter = { @@ -960,6 +961,7 @@ const AddDialogComponent = ({ open, info, validateResult: validateResultProp, on formValues: fv, tableColumns: tableColumns.rawVal, testSuite: testSuite.rawVal, + qualifiesTableRefsWithSchema: qualifiesTableRefsWithSchema.rawVal, validateResult: vr, mode: 'add', onFormChange: (changes) => { @@ -979,6 +981,7 @@ const EditDialogComponent = ({ open, info, validateResult: validateResultProp, o const dialogInfo = van.derive(() => getValue(info) ?? null); const tableColumns = van.derive(() => dialogInfo.val?.table_columns ?? []); const testSuite = van.derive(() => dialogInfo.val?.test_suite ?? {}); + const qualifiesTableRefsWithSchema = van.derive(() => dialogInfo.val?.qualifies_table_refs_with_schema ?? true); const validateResult = van.derive(() => getValue(validateResultProp) ?? null); const formValues = van.state(null); @@ -1022,6 +1025,7 @@ const EditDialogComponent = ({ open, info, validateResult: validateResultProp, o formValues: fv, tableColumns: tableColumns.rawVal, testSuite: testSuite.rawVal, + qualifiesTableRefsWithSchema: qualifiesTableRefsWithSchema.rawVal, validateResult: vr, mode: 'edit', onFormChange: (changes) => { @@ -1037,7 +1041,7 @@ const EditDialogComponent = ({ open, info, validateResult: validateResultProp, o }; // Shared form content for add/edit dialogs -const TestDefFormContent = ({ formValues, tableColumns, testSuite, validateResult, mode, onFormChange, onValidate, onSave, onCancel }) => { +const TestDefFormContent = ({ formValues, tableColumns, testSuite, validateResult, mode, qualifiesTableRefsWithSchema, onFormChange, onValidate, onSave, onCancel }) => { const testScope = formValues.test_scope ?? 'column'; const runType = formValues.run_type ?? 'CAT'; const testType = formValues.test_type ?? ''; @@ -1177,12 +1181,14 @@ const TestDefFormContent = ({ formValues, tableColumns, testSuite, validateResul ), // Schema (read-only) - Input({ - name: 'schema_name', - label: 'Schema', - value: formValues.schema_name ?? '', - disabled: true, - }), + qualifiesTableRefsWithSchema + ? Input({ + name: 'schema_name', + label: 'Schema', + value: formValues.schema_name ?? '', + disabled: true, + }) + : null, // Table name testScope !== 'tablegroup' @@ -1242,6 +1248,7 @@ const TestDefFormContent = ({ formValues, tableColumns, testSuite, validateResul { class: 'td-form-params-section' }, TestDefinitionForm({ definition: formValues, + qualifiesTableRefsWithSchema, onChange: (changes) => { if (Object.keys(changes).length === 0) return; const updated = { ...fv.rawVal, ...changes }; diff --git a/testgen/ui/queries/table_group_queries.py b/testgen/ui/queries/table_group_queries.py index 8cc32c73..f52fdf0e 100644 --- a/testgen/ui/queries/table_group_queries.py +++ b/testgen/ui/queries/table_group_queries.py @@ -5,8 +5,10 @@ import streamlit as st -from testgen.commands.queries.refresh_data_chars_query import ColumnChars, RefreshDataCharsSQL +from testgen.commands.queries.refresh_data_chars_query import RefreshDataCharsSQL from testgen.commands.run_refresh_data_chars import write_data_chars +from testgen.common.database.column_chars import ColumnChars +from testgen.common.database.flavor.flavor_service import resolve_connection_params from testgen.common.models.connection import Connection from testgen.common.models.table_group import TableGroup from testgen.ui.services.database_service import fetch_from_target_db @@ -109,8 +111,13 @@ def _get_preview( connection: Connection, ) -> tuple[TableGroupPreview, list[ColumnChars], RefreshDataCharsSQL]: sql_generator = RefreshDataCharsSQL(connection, table_group) - data_chars = fetch_from_target_db(connection, *sql_generator.get_schema_ddf()) - data_chars = [ColumnChars(**column) for column in data_chars] + if sql_generator.flavor_service.metadata_via_api: + params = resolve_connection_params(connection.__dict__) + api_columns = sql_generator.flavor_service.get_schema_columns(params, table_group.table_group_schema) or [] + data_chars = sql_generator.filter_schema_columns(api_columns) + else: + rows = fetch_from_target_db(connection, *sql_generator.get_schema_ddf()) + data_chars = [ColumnChars(**column) for column in rows] preview: TableGroupPreview = { "stats": { diff --git a/testgen/ui/static/js/components/connection_form.js b/testgen/ui/static/js/components/connection_form.js index 0c75655a..0236996f 100644 --- a/testgen/ui/static/js/components/connection_form.js +++ b/testgen/ui/static/js/components/connection_form.js @@ -87,6 +87,9 @@ const defaultPorts = { sap_hana: '39015', }; +// Salesforce Data 360's Hyper engine has a lower expression-depth limit than other databases +const defaultMaxQueryChars = (flavorCode) => flavorCode === 'salesforce_data360' ? 15000 : 20000; + /** * * @param {Properties} props @@ -114,7 +117,7 @@ const ConnectionForm = (props, saveButton) => { const connectionFlavor = van.state(connection?.sql_flavor_code); const connectionName = van.state(connection?.connection_name ?? ''); const connectionMaxThreads = van.state(connection?.max_threads ?? 4); - const connectionQueryChars = van.state(connection?.max_query_chars ?? 20000); + const connectionQueryChars = van.state(connection?.max_query_chars ?? defaultMaxQueryChars(connection?.sql_flavor_code)); const privateKeyFile = van.state(getValue(props.cachedPrivateKeyFile) ?? null); const serviceAccountKeyFile = van.state(getValue(props.cachedServiceAccountKeyFile) ?? null); @@ -139,7 +142,7 @@ const ConnectionForm = (props, saveButton) => { sql_flavor_code: connectionFlavor.rawVal ?? '', connection_name: connectionName.rawVal ?? '', max_threads: connectionMaxThreads.rawVal ?? 4, - max_query_chars: connectionQueryChars.rawVal ?? 20000, + max_query_chars: connectionQueryChars.rawVal ?? defaultMaxQueryChars(connectionFlavor.rawVal), }); const dynamicConnectionUrl = van.state(props.dynamicConnectionUrl?.rawVal ?? ''); @@ -179,6 +182,7 @@ const ConnectionForm = (props, saveButton) => { setFieldValidity('redshift_spectrum_form', isValid); }, connection, + dynamicConnectionUrl, ), azure_mssql: () => AzureMSSQLForm( updatedConnection, @@ -274,6 +278,17 @@ const ConnectionForm = (props, saveButton) => { connection, getValue(props.cachedServiceAccountKeyFile) ?? null ), + salesforce_data360: () => SalesforceData360Form( + updatedConnection, + getValue(props.flavors).find(f => f.value === connectionFlavor.rawVal), + (formValue, fileValue, isValid) => { + updatedConnection.val = {...updatedConnection.val, ...formValue}; + privateKeyFile.val = fileValue; + setFieldValidity('salesforce_data360_form', isValid); + }, + connection, + getValue(props.cachedPrivateKeyFile) ?? null, + ), }; const setFieldValidity = (field, validity) => { @@ -287,17 +302,6 @@ const ConnectionForm = (props, saveButton) => { return authenticationForms[flavor.value](); }); - van.derive(() => { - const selectedFlavorCode = connectionFlavor.val; - const previousFlavorCode = connectionFlavor.oldVal; - const updatedConnection_ = updatedConnection.rawVal; - - const isCustomPort = updatedConnection_?.project_port !== defaultPorts[previousFlavorCode]; - if (selectedFlavorCode !== previousFlavorCode && (!isCustomPort || !updatedConnection_?.project_port)) { - updatedConnection.val = {...updatedConnection_, project_port: defaultPorts[selectedFlavorCode]}; - } - }); - van.derive(() => { const selectedFlavor = connectionFlavor.val; const flavorObject = getValue(props.flavors).find(f => f.value === selectedFlavor); @@ -409,7 +413,6 @@ const ConnectionForm = (props, saveButton) => { /** * @param {VanState} connection * @param {Flavor} flavor - * @param {boolean} maskPassword * @param {(params: Partial, isValid: boolean) => void} onChange * @param {Connection?} originalConnection * @param {VanState} dynamicConnectionUrl @@ -788,7 +791,6 @@ const MSSQLForm = RedshiftForm; /** * @param {VanState} connection * @param {Flavor} flavor - * @param {boolean} maskPassword * @param {(params: Partial, isValid: boolean) => void} onChange * @param {Connection?} originalConnection * @param {VanState} dynamicConnectionUrl @@ -1030,7 +1032,6 @@ const DatabricksForm = ( /** * @param {VanState} connection * @param {Flavor} flavor - * @param {boolean} maskPassword * @param {(params: Partial, fileValue: FileValue, isValid: boolean) => void} onChange * @param {Connection?} originalConnection * @param {string?} cachedFile @@ -1327,6 +1328,162 @@ const SnowflakeForm = ( ); }; +/** + * @param {VanState} connection + * @param {Flavor} flavor + * @param {(params: Partial, fileValue: FileValue, isValid: boolean) => void} onChange + * @param {Connection?} originalConnection + * @param {string?} cachedFile + * @returns {HTMLElement} + */ +const SalesforceData360Form = ( + connection, + flavor, + onChange, + originalConnection, + cachedFile, +) => { + const isValid = van.state(false); + const authMethod = van.state( + originalConnection?.connection_id + ? (connection.rawVal.connect_by_key ? 'jwt' : 'client_credentials') + : 'jwt' + ); + const loginUrl = van.state(connection.rawVal.project_host ?? ''); + const consumerKey = van.state(connection.rawVal.project_user ?? ''); + const consumerSecret = van.state(connection.rawVal?.project_pw_encrypted ?? ''); + const permittedUser = van.state(connection.rawVal.project_db ?? ''); + const connectionPrivateKey = van.state(connection.rawVal?.private_key ?? ''); + + const validityPerField = {}; + + const privateKeyFileRaw = van.state(cachedFile); + + van.derive(() => { + onChange({ + project_host: loginUrl.val, + project_user: consumerKey.val, + project_pw_encrypted: consumerSecret.val, + project_db: permittedUser.val, + connect_by_key: authMethod.val === 'jwt', + private_key: connectionPrivateKey.val, + }, privateKeyFileRaw.val, isValid.val); + }); + + return div( + { class: 'flex-column fx-gap-3 fx-flex' }, + div( + { class: 'flex-column border border-radius-1 p-3 mt-1 fx-gap-1', style: 'position: relative;' }, + Caption({ content: 'Org', style: 'position: absolute; top: -10px; background: var(--app-background-color); padding: 0px 8px;' }), + Input({ + name: 'login_url', + label: 'Login URL', + help: 'My Domain URL of the Salesforce org', + value: loginUrl, + onChange: (value, state) => { + loginUrl.val = value; + validityPerField['login_url'] = state.valid; + isValid.val = Object.values(validityPerField).every(v => v); + }, + validators: [required, maxLength(250)], + }), + ), + div( + { class: 'flex-column border border-radius-1 p-3 mt-1 fx-gap-1', style: 'position: relative;' }, + Caption({ content: 'Authentication', style: 'position: absolute; top: -10px; background: var(--app-background-color); padding: 0px 8px;' }), + RadioGroup({ + label: 'Connection Strategy', + options: [ + { label: 'JWT Bearer Flow', value: 'jwt' }, + { label: 'Client Credentials Flow', value: 'client_credentials' }, + ], + value: authMethod, + onChange: (value) => authMethod.val = value, + layout: 'inline', + }), + Input({ + name: 'consumer_key', + label: 'Consumer Key', + help: 'Consumer key from the Salesforce external client app', + value: consumerKey, + onChange: (value, state) => { + consumerKey.val = value; + validityPerField['consumer_key'] = state.valid; + isValid.val = Object.values(validityPerField).every(v => v); + }, + validators: [required, maxLength(250)], + }), + () => { + if (authMethod.val === 'jwt') { + return div( + { class: 'flex-column fx-gap-3' }, + Input({ + name: 'permitted_user', + label: 'Username', + help: 'Salesforce user the JWT token will impersonate. Must be pre-authorized on the external client app.', + value: permittedUser, + onChange: (value, state) => { + permittedUser.val = value; + validityPerField['permitted_user'] = state.valid; + isValid.val = Object.values(validityPerField).every(v => v); + }, + validators: [ + requiredIf(() => authMethod.val === 'jwt'), + maxLength(250), + ], + }), + FileInput({ + name: 'private_key', + label: 'Upload private key (.pem, .key)', + placeholder: (originalConnection?.connection_id && originalConnection?.private_key) + ? 'Drop file here or browse files to replace existing key' + : undefined, + value: privateKeyFileRaw, + onChange: (value, state) => { + let isFieldValid = state.valid; + + privateKeyFileRaw.val = value; + try { + if (value?.content) { + connectionPrivateKey.val = value.content.split(',')?.[1] ?? ''; + } + } catch (err) { + console.error(err); + isFieldValid = false; + } + + validityPerField['private_key'] = isFieldValid; + isValid.val = Object.values(validityPerField).every(v => v); + }, + validators: [ + requiredIf(() => authMethod.val === 'client_credentials' && !originalConnection?.connection_id || !originalConnection?.private_key), + sizeLimit(200 * 1024 * 1024), + ], + }), + ); + }; + return Input({ + name: 'consumer_secret', + label: 'Consumer Secret', + help: 'Consumer secret from the Salesforce external client app', + type: 'password', + passwordSuggestions: false, + value: consumerSecret, + placeholder: (originalConnection?.connection_id && originalConnection?.project_pw_encrypted) ? secretsPlaceholder : '', + onChange: (value, state) => { + consumerSecret.val = value; + validityPerField['consumer_secret'] = state.valid; + isValid.val = Object.values(validityPerField).every(v => v); + }, + validators: [ + requiredIf(() => authMethod.val === 'jwt' && !originalConnection?.connection_id || !originalConnection?.project_pw_encrypted), + ], + }); + }, + ), + ); +}; + /** * @param {VanState} connection * @param {Flavor} flavor diff --git a/testgen/ui/static/js/components/table_group_form.js b/testgen/ui/static/js/components/table_group_form.js index 46dd4353..93becaa9 100644 --- a/testgen/ui/static/js/components/table_group_form.js +++ b/testgen/ui/static/js/components/table_group_form.js @@ -1,6 +1,6 @@ /** * @import { Connection } from './connection_form.js'; - * + * * @typedef TableGroup * @type {object} * @property {string?} id @@ -29,12 +29,12 @@ * @property {string?} stakeholder_group * @property {string?} transform_level * @property {string?} data_product - * + * * @typedef FormState * @type {object} * @property {boolean} dirty * @property {boolean} valid - * + * * @typedef Properties * @type {object} * @property {TableGroup} tableGroup @@ -42,6 +42,7 @@ * @property {boolean?} showConnectionSelector * @property {boolean?} disableConnectionSelector * @property {boolean?} disableSchemaField + * @property {string?} sqlFlavor * @property {boolean?} disablePiiFlag * @property {(tg: TableGroup, state: FormState) => void} onChange */ @@ -65,9 +66,9 @@ const normalizeTableSet = (value) => { } /** - * - * @param {Properties} props - * @returns + * + * @param {Properties} props + * @returns */ const TableGroupForm = (props) => { loadStylesheet('table-group-form', stylesheet); @@ -111,6 +112,13 @@ const TableGroupForm = (props) => { const showConnectionSelector = getValue(props.showConnectionSelector) ?? false; const disableSchemaField = van.derive(() => getValue(props.disableSchemaField) ?? false) + const isSalesforce = van.derive(() => { + const connections = getValue(props.connections) ?? []; + const selected = connections.find(c => c.connection_id === tableGroupConnectionId.val); + const flavor = selected?.sql_flavor ?? getValue(props.sqlFlavor); + return flavor === 'salesforce_data360'; + }); + const updatedTableGroup = van.derive(() => { return { id: tableGroup.id, @@ -176,7 +184,7 @@ const TableGroupForm = (props) => { }) : undefined, MainForm( - { disableSchemaField, setValidity: setFieldValidity }, + { disableSchemaField, isSalesforce, setValidity: setFieldValidity }, tableGroupsName, tableGroupSchema, ), @@ -238,12 +246,14 @@ const MainForm = ( }, validators: [ required ], }), - Input({ + () => Input({ name: 'table_group_schema', - label: 'Schema', + label: getValue(options.isSalesforce) ? 'Data Space' : 'Schema', value: tableGroupSchema, class: 'tg-column-flex', - help: 'Database schema containing the tables for the Table Group', + help: getValue(options.isSalesforce) + ? 'Salesforce data space containing the tables for the Table Group' + : 'Database schema containing the tables for the Table Group', helpPlacement: 'bottom-left', disabled: options.disableSchemaField, onChange: (value, state) => { @@ -340,7 +350,7 @@ const SettingsForm = ( ) => { return div( { class: 'flex-row fx-gap-3 fx-flex-wrap fx-align-flex-start border border-radius-1 p-3 mt-1', style: 'position: relative;' }, - Caption({content: 'Settings', style: 'position: absolute; top: -10px; background: var(--app-background-color); padding: 0px 8px;' }), + Caption({content: 'Settings', style: 'position: absolute; top: -10px; background: var(--app-background-color); padding: 0px 8px;' }), div( { class: 'tg-column-flex flex-column fx-gap-3' }, Checkbox({ diff --git a/testgen/ui/static/js/components/test_definition_form.js b/testgen/ui/static/js/components/test_definition_form.js index 18b173dc..e7fa14d6 100644 --- a/testgen/ui/static/js/components/test_definition_form.js +++ b/testgen/ui/static/js/components/test_definition_form.js @@ -60,6 +60,7 @@ * @type {object} * @property {TestDefinition} definition * @property {string?} class + * @property {boolean} qualifiesTableRefsWithSchema * @property {(changes: object, valid: boolean) => void} onChange */ @@ -83,7 +84,7 @@ const thresholdColumns = [ ]; // Columns using the default { type: 'text' } do not need to be specified here -const PARAMETER_CONFIG = { +const PARAMETER_CONFIG = { custom_query: { type: 'textarea' }, lower_tolerance: { type: 'number' }, upper_tolerance: { type: 'number' }, @@ -94,6 +95,7 @@ const TestDefinitionForm = (/** @type Properties */ props) => { loadStylesheet('test-definition-form', stylesheet); const definition = getValue(props.definition); + const qualifiesTableRefsWithSchema = getValue(props.qualifiesTableRefsWithSchema) ?? true; const paramColumns = (definition.default_parm_columns || '').split(',').map(v => v.trim()); const paramLabels = (definition.default_parm_prompts || '').split(',').map(v => v.trim()); @@ -110,6 +112,8 @@ const TestDefinitionForm = (/** @type Properties */ props) => { validators: paramRequired[index] ? [required] : undefined, })) .filter(config => !hasThresholds || !thresholdColumns.includes(config.column)) + // Drop the field for flavors whose SQL doesn't qualify table refs with a schema + .filter(config => qualifiesTableRefsWithSchema || config.column !== 'match_schema_name') const updatedDefinition = van.state({ ...definition }); const validityPerField = van.state({}); @@ -258,9 +262,9 @@ const historyCalcOptions = [ * @type {object} * @property {(updatedValues: object) => void} setFieldValues * @property {(field: string, valid: boolean) => void} setFieldValidity - * - * @param {ThresholdFormOptions} options - * @param {TestDefinition} definition + * + * @param {ThresholdFormOptions} options + * @param {TestDefinition} definition */ const ThresholdForm = (options, definition) => { const { setFieldValues, setFieldValidity } = options; diff --git a/testgen/ui/views/connections.py b/testgen/ui/views/connections.py index a1452bb6..5faaf0a4 100644 --- a/testgen/ui/views/connections.py +++ b/testgen/ui/views/connections.py @@ -638,6 +638,12 @@ class ConnectionFlavor: flavor="sap_hana", icon=get_asset_data_url("flavors/sap_hana.svg"), ), + ConnectionFlavor( + label="Salesforce Data 360", + value="salesforce_data360", + flavor="salesforce_data360", + icon=get_asset_data_url("flavors/salesforce_data360.svg"), + ), ConnectionFlavor( label="Snowflake", value="snowflake", diff --git a/testgen/ui/views/data_catalog.py b/testgen/ui/views/data_catalog.py index 404b85ff..0f09c5bf 100644 --- a/testgen/ui/views/data_catalog.py +++ b/testgen/ui/views/data_catalog.py @@ -117,7 +117,7 @@ def render(self, project_code: str, table_group_id: str | None = None, selected: selected_item["connection_id"] = str(selected_table_group.connection_id) else: on_item_selected(None) - + def on_run_profiling_clicked(_) -> None: if selected_table_group: st.session_state[DC_RUN_PROFILING_DIALOG_KEY] = str(selected_table_group.id) @@ -466,7 +466,7 @@ def get_excel_report_data(update_progress: PROGRESS_UPDATE_TYPE, table_group: Ta include_tags=True, include_active_tests=True, ) - + data = pd.DataFrame(table_data + column_data) @@ -684,7 +684,7 @@ def on_tags_changed(spinner_container: DeltaGenerator, payload: dict) -> FILE_DA def get_table_group_columns(table_group_id: str) -> list[dict]: if not is_uuid4(table_group_id): return [] - + query = f""" SELECT CONCAT('column_', column_chars.column_id) AS column_id, CONCAT('table_', table_chars.table_id) AS table_id, @@ -902,11 +902,12 @@ def get_preview_data( flavor_service = get_flavor_service(connection.sql_flavor) prefix, suffix = flavor_service.row_limit_clauses(100) quote = flavor_service.quote_character + table_ref = flavor_service.get_table_ref(schema_name, table_name) query = f""" SELECT DISTINCT {prefix} {f"{quote}{column_name}{quote}" if column_name else "*"} - FROM {quote}{schema_name}{quote}.{quote}{table_name}{quote} + FROM {table_ref} {suffix} """ diff --git a/testgen/ui/views/test_definitions.py b/testgen/ui/views/test_definitions.py index a73f30c9..12127bf9 100644 --- a/testgen/ui/views/test_definitions.py +++ b/testgen/ui/views/test_definitions.py @@ -184,6 +184,12 @@ def render( # Build dialog states validate_result = st.session_state.pop(TD_VALIDATE_RESULT_KEY, None) + qualifies_table_refs_with_schema = True + if st.session_state.get(TD_ADD_DIALOG_KEY) or st.session_state.get(TD_EDIT_DIALOG_KEY): + connection = Connection.get(table_group.connection_id) + if connection: + qualifies_table_refs_with_schema = get_flavor_service(connection.sql_flavor).qualifies_table_refs_with_schema + add_dialog = None if st.session_state.get(TD_ADD_DIALOG_KEY): add_dialog = { @@ -193,6 +199,7 @@ def render( "table_groups_id": str(table_group.id), "table_group_schema": table_group.table_group_schema, "test_suite": test_suite_info, + "qualifies_table_refs_with_schema": qualifies_table_refs_with_schema, } edit_dialog = None @@ -204,6 +211,7 @@ def render( "table_columns": table_columns, "table_group_schema": table_group.table_group_schema, "test_suite": test_suite_info, + "qualifies_table_refs_with_schema": qualifies_table_refs_with_schema, } delete_dialog = None @@ -320,9 +328,27 @@ def on_copy_move_dialog_closed(*_) -> None: st.session_state.pop(TD_COPY_MOVE_COLLISION_KEY, None) st.session_state.pop(TD_COPY_MOVE_OVERWRITE_KEY, None) + match_schema_test_types = { + tt["test_type"] + for tt in test_types + if "match_schema_name" in (tt.get("default_parm_columns") or "").split(",") + } + + def _default_match_schema(test_def: dict) -> None: + # The Match Schema field is hidden in the UI for flavors whose SQL doesn't + # qualify table refs with a schema, but downstream SQL/Python still expects + # match_schema_name populated for tests that support it. Default to the + # test's schema (or table-group schema) when the test type accepts + # match_schema_name and match_table_name is set. + if test_def.get("test_type") not in match_schema_test_types: + return + if test_def.get("match_table_name") and not test_def.get("match_schema_name"): + test_def["match_schema_name"] = test_def.get("schema_name") or table_group.table_group_schema + @with_database_session def on_add_test_saved(test_def: dict) -> None: test_def["last_manual_update"] = datetime.now(UTC) + _default_match_schema(test_def) td_columns = set(TestDefinition.__table__.columns.keys()) TestDefinition(**{k: v for k, v in test_def.items() if k in td_columns}).save() st.cache_data.clear() @@ -332,6 +358,7 @@ def on_add_test_saved(test_def: dict) -> None: @with_database_session def on_edit_test_saved(test_def: dict) -> None: test_def["last_manual_update"] = datetime.now(UTC) + _default_match_schema(test_def) td_columns = set(TestDefinition.__table__.columns.keys()) TestDefinition(**{k: v for k, v in test_def.items() if k in td_columns}).save() st.cache_data.clear() @@ -955,7 +982,7 @@ def validate_test(test_definition: dict, table_group: TableGroupMinimal) -> None condition = test_definition["custom_query"] flavor_service = get_flavor_service(connection.sql_flavor) concat_operator = flavor_service.concat_operator - quote = flavor_service.quote_character + table_ref = flavor_service.get_table_ref(schema, table_name) query = f""" SELECT COALESCE( @@ -967,7 +994,7 @@ def validate_test(test_definition: dict, table_group: TableGroupMinimal) -> None {concat_operator} '|', '|' ) - FROM {quote}{schema}{quote}.{quote}{table_name}{quote}; + FROM {table_ref}; """ fetch_from_target_db(connection, query) else: diff --git a/tests/unit/commands/queries/test_refresh_data_chars_query.py b/tests/unit/commands/queries/test_refresh_data_chars_query.py index 47c52179..5dc3ff16 100644 --- a/tests/unit/commands/queries/test_refresh_data_chars_query.py +++ b/tests/unit/commands/queries/test_refresh_data_chars_query.py @@ -1,12 +1,20 @@ import pytest from testgen.commands.queries.refresh_data_chars_query import RefreshDataCharsSQL +from testgen.common.database.column_chars import ColumnChars from testgen.common.models.connection import Connection from testgen.common.models.table_group import TableGroup pytestmark = pytest.mark.unit +def _make_columns(*table_names: str) -> list[ColumnChars]: + return [ + ColumnChars(schema_name="default", table_name=name, column_name="id") + for name in table_names + ] + + @pytest.mark.parametrize( "flavor,expected_sql", [ @@ -127,3 +135,85 @@ def test_table_set_with_include_exclude(): assert "LIKE 'important%'" in criteria assert "AND NOT" in criteria assert "LIKE 'temp%'" in criteria + + +def test_filter_schema_columns_table_set(): + connection = Connection(sql_flavor="salesforce_data360") + table_group = TableGroup( + table_group_schema="default", + profiling_table_set="users, orders", + profiling_include_mask="", + profiling_exclude_mask="", + ) + sql_generator = RefreshDataCharsSQL(connection, table_group) + columns = _make_columns("users", "orders", "products", "logs") + + filtered = sql_generator.filter_schema_columns(columns) + + assert {c.table_name for c in filtered} == {"users", "orders"} + + +def test_filter_schema_columns_include_mask(): + connection = Connection(sql_flavor="salesforce_data360") + table_group = TableGroup( + table_group_schema="default", + profiling_table_set="", + profiling_include_mask="party_%, summary", + profiling_exclude_mask="", + ) + sql_generator = RefreshDataCharsSQL(connection, table_group) + columns = _make_columns("party_planners", "party_transactions", "summary", "audit_log") + + filtered = sql_generator.filter_schema_columns(columns) + + assert {c.table_name for c in filtered} == {"party_planners", "party_transactions", "summary"} + + +def test_filter_schema_columns_exclude_mask(): + connection = Connection(sql_flavor="salesforce_data360") + table_group = TableGroup( + table_group_schema="default", + profiling_table_set="", + profiling_include_mask="", + profiling_exclude_mask="tmp_%, raw_log", + ) + sql_generator = RefreshDataCharsSQL(connection, table_group) + columns = _make_columns("users", "tmp_x", "tmp_y", "raw_log", "orders") + + filtered = sql_generator.filter_schema_columns(columns) + + assert {c.table_name for c in filtered} == {"users", "orders"} + + +def test_filter_schema_columns_underscore_is_literal(): + # SQL LIKE _ wildcard semantics: the existing SQL path escapes user `_` to `\_`, + # treating `_` as a literal. The Python filter must match that behavior. + connection = Connection(sql_flavor="salesforce_data360") + table_group = TableGroup( + table_group_schema="default", + profiling_table_set="", + profiling_include_mask="a_b", + profiling_exclude_mask="", + ) + sql_generator = RefreshDataCharsSQL(connection, table_group) + columns = _make_columns("a_b", "axb", "axxb") + + filtered = sql_generator.filter_schema_columns(columns) + + assert {c.table_name for c in filtered} == {"a_b"} + + +def test_filter_schema_columns_no_filters_returns_all(): + connection = Connection(sql_flavor="salesforce_data360") + table_group = TableGroup( + table_group_schema="default", + profiling_table_set="", + profiling_include_mask="", + profiling_exclude_mask="", + ) + sql_generator = RefreshDataCharsSQL(connection, table_group) + columns = _make_columns("users", "orders") + + filtered = sql_generator.filter_schema_columns(columns) + + assert {c.table_name for c in filtered} == {"users", "orders"} diff --git a/tests/unit/common/test_salesforce_data360_flavor.py b/tests/unit/common/test_salesforce_data360_flavor.py new file mode 100644 index 00000000..48bcab65 --- /dev/null +++ b/tests/unit/common/test_salesforce_data360_flavor.py @@ -0,0 +1,415 @@ +"""Unit tests for Salesforce Data 360 flavor support.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from testgen.common.database.flavor.flavor_service import ResolvedConnectionParams, resolve_connection_params +from testgen.common.database.flavor.salesforce_data360_flavor_service import ( + _TYPE_MAP, + SalesforceData360FlavorService, +) + + +@pytest.fixture +def flavor_service(): + return SalesforceData360FlavorService() + + +@pytest.fixture +def client_credentials_params(): + return ResolvedConnectionParams( + host="https://myorg.my.salesforce.com", + username="consumer_key_123", + password="consumer_secret_456", # noqa: S106 + dbname="", + connect_by_key=False, + sql_flavor="salesforce_data360", + ) + + +@pytest.fixture +def jwt_bearer_params(): + return ResolvedConnectionParams( + host="https://myorg.my.salesforce.com", + username="consumer_key_123", + dbname="admin@myorg.com", + connect_by_key=True, + private_key="-----BEGIN RSA PRIVATE KEY-----\ntest\n-----END RSA PRIVATE KEY-----", + sql_flavor="salesforce_data360", + ) + + +# --- FlavorService class properties --- + +def test_flavor_service_properties(flavor_service): + assert flavor_service.concat_operator == "||" + assert flavor_service.quote_character == '"' + assert flavor_service.varchar_type == "VARCHAR(1000)" + assert flavor_service.default_uppercase is False + assert flavor_service.test_query == "SELECT 1" + assert flavor_service.qualifies_table_refs_with_schema is False + assert flavor_service.metadata_via_api is True + assert flavor_service.row_limiting_clause == "limit" + + +def test_get_table_ref_omits_schema(flavor_service): + assert flavor_service.get_table_ref("data_space", "Account__dll") == '"Account__dll"' + + +# --- Connection string --- + +def test_connection_string_is_dummy(flavor_service, client_credentials_params): + assert flavor_service.get_connection_string(client_credentials_params) == "salesforce_data360://" + + +def test_connection_string_from_fields(flavor_service, client_credentials_params): + assert flavor_service.get_connection_string_from_fields(client_credentials_params) == "salesforce_data360://" + + +# --- Connect args: Client Credentials flow --- + +def test_connect_args_client_credentials(flavor_service, client_credentials_params): + args = flavor_service.get_connect_args(client_credentials_params) + assert args["login_url"] == "https://myorg.my.salesforce.com" + assert args["client_id"] == "consumer_key_123" + assert args["client_secret"] == "consumer_secret_456" # noqa: S105 + assert "username" not in args + assert "private_key" not in args + assert "dataspace" not in args # connection-only contexts (Test Connection) + + +# --- Connect args: JWT Bearer flow --- + +def test_connect_args_jwt_bearer(flavor_service, jwt_bearer_params): + args = flavor_service.get_connect_args(jwt_bearer_params) + assert args["login_url"] == "https://myorg.my.salesforce.com" + assert args["client_id"] == "consumer_key_123" + assert args["username"] == "admin@myorg.com" + assert args["private_key"].startswith("-----BEGIN RSA PRIVATE KEY-----") + assert "client_secret" not in args + assert "dataspace" not in args # connection-only contexts (Test Connection) + + +# --- Connect args: Data Space pass-through --- + +def test_connect_args_passes_dataspace_when_table_group_schema_set(flavor_service): + params = ResolvedConnectionParams( + host="https://myorg.my.salesforce.com", + username="consumer_key_123", + password="consumer_secret_456", # noqa: S106 + dbname="", + dbschema="marketing", + connect_by_key=False, + sql_flavor="salesforce_data360", + ) + args = flavor_service.get_connect_args(params) + assert args["dataspace"] == "marketing" + + +def test_connect_args_omits_dataspace_when_table_group_schema_empty(flavor_service): + params = ResolvedConnectionParams( + host="https://myorg.my.salesforce.com", + username="consumer_key_123", + dbname="admin@myorg.com", + dbschema="", + connect_by_key=True, + private_key="-----BEGIN RSA PRIVATE KEY-----\ntest\n-----END RSA PRIVATE KEY-----", + sql_flavor="salesforce_data360", + ) + args = flavor_service.get_connect_args(params) + assert "dataspace" not in args + + +# --- Engine args --- + +def test_engine_args(flavor_service, client_credentials_params): + args = flavor_service.get_engine_args(client_credentials_params) + assert args["pool_pre_ping"] is False + assert "poolclass" in args + + +# --- Pre-connection queries --- + +def test_no_pre_connection_queries(flavor_service, client_credentials_params): + assert flavor_service.get_pre_connection_queries(client_credentials_params) == [] + + +# --- Table reference (no schema prefix) --- + +def test_get_table_ref_no_schema(flavor_service): + ref = flavor_service.get_table_ref("default", "ssot__Account__dlm") + assert ref == '"ssot__Account__dlm"' + assert "default" not in ref + + +# --- resolve_connection_params mapping --- + +def test_resolve_connection_params_mapping(): + # Use plain strings (not bytes) to avoid triggering the DecryptText path + params = resolve_connection_params({ + "sql_flavor": "salesforce_data360", + "project_host": "https://myorg.my.salesforce.com", + "project_user": "consumer_key", + "project_pw_encrypted": "plain_secret", + "project_db": "admin@org.com", + "connect_by_key": True, + "private_key": "plain_key", + }) + assert params.host == "https://myorg.my.salesforce.com" + assert params.username == "consumer_key" + assert params.password == "plain_secret" # noqa: S105 + assert params.dbname == "admin@org.com" + assert params.connect_by_key is True + assert params.private_key == "plain_key" + + +# --- Schema metadata (get_schema_columns) --- + +def test_get_schema_columns_returns_columns(flavor_service, client_credentials_params): + mock_field = MagicMock() + mock_field.name = "ssot__Name__c" + mock_field.type = "STRING" + + mock_table = MagicMock() + mock_table.name = "ssot__Account__dlm" + mock_table.fields = [mock_field] + + mock_conn = MagicMock() + mock_conn.list_tables.return_value = [mock_table] + + with patch( + "salesforcecdpconnector.connection.SalesforceCDPConnection", + return_value=mock_conn, + ): + columns = flavor_service.get_schema_columns(client_credentials_params, "default") + + assert columns is not None + assert len(columns) == 1 + assert columns[0].schema_name == "default" + assert columns[0].table_name == "ssot__Account__dlm" + assert columns[0].column_name == "ssot__Name__c" + assert columns[0].column_type == "varchar" + assert columns[0].general_type == "A" + assert columns[0].db_data_type == "STRING" + assert columns[0].ordinal_position == 1 + assert columns[0].is_decimal is False + + +def test_get_schema_columns_type_mapping(flavor_service, client_credentials_params): + """Verify all metadata types map correctly.""" + type_cases = [ + ("STRING", "varchar", "A", False), + ("NUMBER", "numeric", "N", True), + ("BIGINT", "bigint", "N", False), + ("BOOLEAN", "boolean", "B", False), + ("DATE", "date", "D", False), + ("DATE_TIME", "datetime", "D", False), + ] + + for meta_type, expected_col_type, expected_gen_type, expected_decimal in type_cases: + mock_field = MagicMock() + mock_field.name = "test_col" + mock_field.type = meta_type + + mock_table = MagicMock() + mock_table.name = "test_table" + mock_table.fields = [mock_field] + + mock_conn = MagicMock() + mock_conn.list_tables.return_value = [mock_table] + + with patch( + "salesforcecdpconnector.connection.SalesforceCDPConnection", + return_value=mock_conn, + ): + columns = flavor_service.get_schema_columns(client_credentials_params, "default") + + assert columns[0].column_type == expected_col_type, f"Failed for {meta_type}" + assert columns[0].general_type == expected_gen_type, f"Failed for {meta_type}" + assert columns[0].is_decimal == expected_decimal, f"Failed for {meta_type}" + + +def test_get_schema_columns_unknown_type_defaults_to_X(flavor_service, client_credentials_params): + mock_field = MagicMock() + mock_field.name = "exotic_col" + mock_field.type = "HYPERLOGLOG" + + mock_table = MagicMock() + mock_table.name = "test_table" + mock_table.fields = [mock_field] + + mock_conn = MagicMock() + mock_conn.list_tables.return_value = [mock_table] + + with patch( + "salesforcecdpconnector.connection.SalesforceCDPConnection", + return_value=mock_conn, + ): + columns = flavor_service.get_schema_columns(client_credentials_params, "default") + + assert columns[0].general_type == "X" + # Unknown metadata types are preserved as a lowercased column_type so that + # downstream views still surface the raw SF type instead of coercing to varchar. + assert columns[0].column_type == "hyperloglog" + + +def test_get_schema_columns_multiple_tables(flavor_service, client_credentials_params): + tables = [] + for tname, field_count in [("ssot__Account__dlm", 3), ("ssot__Individual__dlm", 5)]: + mock_table = MagicMock() + mock_table.name = tname + mock_table.fields = [] + for i in range(field_count): + f = MagicMock() + f.name = f"field_{i}" + f.type = "STRING" + mock_table.fields.append(f) + tables.append(mock_table) + + mock_conn = MagicMock() + mock_conn.list_tables.return_value = tables + + with patch( + "salesforcecdpconnector.connection.SalesforceCDPConnection", + return_value=mock_conn, + ): + columns = flavor_service.get_schema_columns(client_credentials_params, "default") + + assert len(columns) == 8 + account_cols = [c for c in columns if c.table_name == "ssot__Account__dlm"] + assert len(account_cols) == 3 + individual_cols = [c for c in columns if c.table_name == "ssot__Individual__dlm"] + assert len(individual_cols) == 5 + + +# --- Dialect registration --- + +def test_dialect_is_registered(): + from sqlalchemy.dialects import registry as sa_registry + + # The import of the flavor service module triggers registration + assert "salesforce_data360" in sa_registry.impls + + +# --- Type map completeness --- + +def test_type_map_covers_all_known_types(): + # Data 360's metadata API has a small fixed vocabulary verified against + # profiled DMOs and DLOs. Any unknown type falls through to general_type "X". + expected_types = {"STRING", "NUMBER", "BIGINT", "BOOLEAN", "DATE", "DATE_TIME"} + assert set(_TYPE_MAP.keys()) == expected_types + + +# --- SQL template files exist --- + +def test_template_files_exist(): + from pathlib import Path + + base = Path(__file__).parents[3] / "testgen" / "template" / "flavors" / "salesforce_data360" + assert (base / "profiling" / "project_profiling_query.sql").exists() + assert (base / "profiling" / "project_secondary_profiling_query.sql").exists() + assert (base / "profiling" / "templated_functions.yaml").exists() + + +# --- Templated functions YAML --- + +def test_templated_functions_yaml_parses(): + from pathlib import Path + + import yaml + + path = Path(__file__).parents[3] / "testgen" / "template" / "flavors" / "salesforce_data360" / "profiling" / "templated_functions.yaml" + with open(path) as f: + data = yaml.safe_load(f) + + # Data 360 uses native DATEDIFF('unit', ...) directly in templates, so the + # DATEDIFF_* macros are intentionally omitted (only IS_NUM / IS_DATE need wrappers). + required_functions = ["IS_NUM", "IS_DATE"] + for func_name in required_functions: + assert func_name in data, f"Missing templated function: {func_name}" + + +def test_profiling_query_uses_data360_datediff_syntax(): + from pathlib import Path + + path = Path(__file__).parents[3] / "testgen" / "template" / "flavors" / "salesforce_data360" / "profiling" / "project_profiling_query.sql" + sql = path.read_text() + + # Data 360 uses inline DATEDIFF('unit', start, end) — string units, not bare identifiers. + assert "DATEDIFF('day'" in sql + assert "DATEDIFF('week'" in sql + assert "DATEDIFF('month'" in sql + + +def test_is_num_uses_regexp_like(): + from pathlib import Path + + import yaml + + path = Path(__file__).parents[3] / "testgen" / "template" / "flavors" / "salesforce_data360" / "profiling" / "templated_functions.yaml" + with open(path) as f: + data = yaml.safe_load(f) + + assert "REGEXP_LIKE" in data["IS_NUM"] + assert "~" not in data["IS_NUM"] # No PG regex operator + + +def test_is_date_uses_regexp_like(): + from pathlib import Path + + import yaml + + path = Path(__file__).parents[3] / "testgen" / "template" / "flavors" / "salesforce_data360" / "profiling" / "templated_functions.yaml" + with open(path) as f: + data = yaml.safe_load(f) + + assert "REGEXP_LIKE" in data["IS_DATE"] + assert "~" not in data["IS_DATE"] + assert "LEFT(" not in data["IS_DATE"] # Should use SUBSTR, not LEFT + assert "::" not in data["IS_DATE"] # Should use CAST, not :: + + +# --- Profiling template syntax checks --- + +def test_profiling_query_has_no_pg_specific_syntax(): + from pathlib import Path + + path = Path(__file__).parents[3] / "testgen" / "template" / "flavors" / "salesforce_data360" / "profiling" / "project_profiling_query.sql" + content = path.read_text() + + assert "TABLESAMPLE" not in content + assert "STRING_AGG" not in content + assert "TRANSLATE(" not in content + assert " ~ " not in content # PG regex operator + # Check for PG escape string syntax (E'...') — but not substrings like "CASE '" + import re + assert not re.search(r"\bE'", content), "Found PostgreSQL E-string escape syntax" + assert "LEFT(" not in content + assert "::FLOAT" not in content + assert "::BIGINT" not in content + assert "::NUMERIC" not in content + + +def test_profiling_query_uses_data360_alternatives(): + from pathlib import Path + + path = Path(__file__).parents[3] / "testgen" / "template" / "flavors" / "salesforce_data360" / "profiling" / "project_profiling_query.sql" + content = path.read_text() + + assert "REGEXP_LIKE" in content + assert "ARRAY_JOIN(ARRAY_AGG" in content + assert "SUBSTR(" in content + assert "ORDER BY RANDOM()" in content + + +def test_secondary_profiling_query_syntax(): + from pathlib import Path + + path = Path(__file__).parents[3] / "testgen" / "template" / "flavors" / "salesforce_data360" / "profiling" / "project_secondary_profiling_query.sql" + content = path.read_text() + + assert "STRING_AGG" not in content + assert "ARRAY_JOIN(ARRAY_AGG" in content + assert "TABLESAMPLE" not in content + assert '"{DATA_SCHEMA}".' not in content # No schema prefix in FROM From 918088c8736bf3e10508d2cf930f01ee13510a52 Mon Sep 17 00:00:00 2001 From: Ricardo Boni Date: Tue, 19 May 2026 16:04:57 -0400 Subject: [PATCH 27/58] feat(mcp): schedule CRUD tools (TG-1071) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Six tools for managing recurring profiling and test-run schedules: create_profiling_schedule, create_test_run_schedule, list_schedules, get_schedule, update_schedule, delete_schedule. Also: - Surface ``schedule_id`` on existing run renderers (test_runs, profiling) - Lift cron validation helper to ``common/cron_service.py`` (was in ``ui/utils.py``) - Rename ``JobSchedule.select_where`` → ``select_runnable``; add clean generic ``select_where`` - Drop ``@st.cache_data`` from ``JobSchedule.get`` Co-Authored-By: Claude Opus 4.7 (1M context) --- testgen/common/cron_service.py | 50 +++ testgen/common/models/profiling_run.py | 2 + testgen/common/models/scheduler.py | 38 +- testgen/common/models/test_run.py | 2 + testgen/mcp/server.py | 14 + testgen/mcp/tools/common.py | 16 +- testgen/mcp/tools/profiling.py | 4 + testgen/mcp/tools/schedules.py | 441 +++++++++++++++++++ testgen/mcp/tools/test_runs.py | 4 + testgen/scheduler/cli_scheduler.py | 2 +- testgen/ui/utils.py | 39 +- testgen/ui/views/dialogs/manage_schedules.py | 2 +- testgen/ui/views/monitors_dashboard.py | 3 +- tests/unit/mcp/test_tools_schedules.py | 358 +++++++++++++++ tests/unit/scheduler/test_scheduler_cli.py | 2 +- 15 files changed, 931 insertions(+), 46 deletions(-) create mode 100644 testgen/common/cron_service.py create mode 100644 testgen/mcp/tools/schedules.py create mode 100644 tests/unit/mcp/test_tools_schedules.py diff --git a/testgen/common/cron_service.py b/testgen/common/cron_service.py new file mode 100644 index 00000000..afe75485 --- /dev/null +++ b/testgen/common/cron_service.py @@ -0,0 +1,50 @@ +import zoneinfo +from datetime import datetime +from typing import TypedDict + +import cron_converter +import cron_descriptor + + +class CronSample(TypedDict, total=False): + id: str | None + error: str | None + samples: list[str] | list[int] | None + readable_expr: str | None + + +def get_cron_sample( + cron_expr: str, + cron_tz: str, + sample_count: int, + *, + reference_time: datetime | None = None, + formatted: bool = False, +) -> CronSample: + try: + cron_obj = cron_converter.Cron(cron_expr) + cron_schedule = cron_obj.schedule(reference_time or datetime.now(zoneinfo.ZoneInfo(cron_tz))) + readable_cron_schedule = cron_descriptor.get_description(cron_expr) + if formatted: + samples = [cron_schedule.next().strftime("%a %b %-d, %-I:%M %p") for _ in range(sample_count)] + else: + samples = [int(cron_schedule.next().timestamp()) for _ in range(sample_count)] + except zoneinfo.ZoneInfoNotFoundError: + return {"error": f"Unknown timezone `{cron_tz}`. Use an IANA name (e.g. `America/New_York`)."} + except ValueError as e: + return {"error": str(e)} + except Exception: + return {"error": "Error validating the Cron expression"} + else: + return { + "samples": samples, + "readable_expr": readable_cron_schedule, + } + + +def describe_cron(cron_expr: str) -> str | None: + """Human-readable description of a cron expression, e.g. ``At 04:00 AM``. Returns ``None`` if unparseable.""" + try: + return cron_descriptor.get_description(cron_expr) + except Exception: + return None diff --git a/testgen/common/models/profiling_run.py b/testgen/common/models/profiling_run.py index 9faf8bf0..20e69a27 100644 --- a/testgen/common/models/profiling_run.py +++ b/testgen/common/models/profiling_run.py @@ -48,6 +48,7 @@ class ProfilingRunMinimal(EntityMinimal): class ProfilingRunSummary(EntityMinimal): job_execution_id: UUID profiling_run_id: UUID | None + job_schedule_id: UUID | None project_code: str status: JobStatus created_at: datetime @@ -243,6 +244,7 @@ def select_summary( SELECT je.id AS job_execution_id, pr.id AS profiling_run_id, + je.job_schedule_id, je.project_code, je.status, je.created_at, diff --git a/testgen/common/models/scheduler.py b/testgen/common/models/scheduler.py index dda95383..f294cf0a 100644 --- a/testgen/common/models/scheduler.py +++ b/testgen/common/models/scheduler.py @@ -3,14 +3,13 @@ from typing import Any, Self from uuid import UUID, uuid4 -import streamlit as st from cron_converter import Cron from sqlalchemy import Boolean, Column, String, cast, delete, func, select, update from sqlalchemy.dialects import postgresql from sqlalchemy.orm import InstrumentedAttribute +from testgen.common.enums import JobKey from testgen.common.models import Base, get_current_session -from testgen.common.models.entity import ENTITY_HASH_FUNCS from testgen.common.models.test_definition import TestDefinition from testgen.common.models.test_suite import TestSuite @@ -18,6 +17,8 @@ RUN_MONITORS_JOB_KEY = "run-monitors" RUN_PROFILE_JOB_KEY = "run-profile" +SCHEDULABLE_JOB_KEYS: frozenset[JobKey] = frozenset({JobKey.run_profile, JobKey.run_tests}) + class JobSchedule(Base): __tablename__ = "job_schedules" @@ -32,13 +33,22 @@ class JobSchedule(Base): active: bool = Column(Boolean, default=True) @classmethod - @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) def get(cls, *clauses) -> Self | None: query = select(cls).where(*clauses) return get_current_session().scalars(query).first() @classmethod def select_where(cls, *clauses, order_by: str | InstrumentedAttribute | None = None) -> Iterable[Self]: + query = select(cls).where(*clauses) + if order_by is not None: + query = query.order_by(order_by) + return get_current_session().scalars(query).all() + + @classmethod + def select_runnable(cls, *clauses, order_by: str | InstrumentedAttribute | None = None) -> Iterable[Self]: + """Schedules the scheduler should dispatch: active rows, and (for test/monitor runs) + only when the linked test suite has at least one test definition. + """ test_job_keys = [RUN_TESTS_JOB_KEY, RUN_MONITORS_JOB_KEY] test_definitions_count = ( select(cls.id) @@ -73,6 +83,28 @@ def update_active(cls, job_id: str | UUID, active: bool) -> None: def count(cls): return get_current_session().query(cls).count() + @classmethod + def list_for_project( + cls, + project_code: str, + *extra_filters, + key_filter: Iterable[JobKey] | None = None, + page: int = 1, + limit: int = 20, + ) -> tuple[list[Self], int]: + """List schedules for a project with optional key filter and pagination. + + Returns both active and paused rows. Defaults ``key_filter`` to + ``SCHEDULABLE_JOB_KEYS`` (``run_profile``, ``run_tests``); pass an explicit + ``key_filter`` to include other kinds. + """ + session = get_current_session() + keys = list(key_filter) if key_filter is not None else list(SCHEDULABLE_JOB_KEYS) + query = select(cls).where(cls.project_code == project_code, cls.key.in_(keys), *extra_filters) + total = session.scalar(select(func.count()).select_from(query.subquery())) + items = session.scalars(query.order_by(cls.key, cls.id).offset((page - 1) * limit).limit(limit)).all() + return list(items), total or 0 + @classmethod def select_active_by_kwargs( cls, diff --git a/testgen/common/models/test_run.py b/testgen/common/models/test_run.py index a001f6ce..27e52a8b 100644 --- a/testgen/common/models/test_run.py +++ b/testgen/common/models/test_run.py @@ -48,6 +48,7 @@ class TestRunMinimal(EntityMinimal): class TestRunSummary(EntityMinimal): job_execution_id: UUID test_run_id: UUID | None + job_schedule_id: UUID | None status: JobStatus created_at: datetime started_at: datetime | None @@ -256,6 +257,7 @@ def select_summary( SELECT je.id AS job_execution_id, tr.id AS test_run_id, + je.job_schedule_id, je.status, je.created_at, je.started_at, diff --git a/testgen/mcp/server.py b/testgen/mcp/server.py index 7b573679..838ceb26 100644 --- a/testgen/mcp/server.py +++ b/testgen/mcp/server.py @@ -174,6 +174,14 @@ def build_mcp_server( hygiene_issue_types_resource, test_types_resource, ) + from testgen.mcp.tools.schedules import ( + create_profiling_schedule, + create_test_run_schedule, + delete_schedule, + get_schedule, + list_schedules, + update_schedule, + ) from testgen.mcp.tools.source_data import get_source_data, get_source_data_query from testgen.mcp.tools.test_definitions import ( bulk_update_tests, @@ -264,6 +272,12 @@ def safe_prompt(fn): safe_tool(get_hygiene_issue) safe_tool(search_hygiene_issues) safe_tool(update_hygiene_issue) + safe_tool(create_profiling_schedule) + safe_tool(create_test_run_schedule) + safe_tool(list_schedules) + safe_tool(get_schedule) + safe_tool(update_schedule) + safe_tool(delete_schedule) # Resources safe_resource("testgen://test-types", test_types_resource) diff --git a/testgen/mcp/tools/common.py b/testgen/mcp/tools/common.py index 6017e449..acc382b6 100644 --- a/testgen/mcp/tools/common.py +++ b/testgen/mcp/tools/common.py @@ -16,7 +16,7 @@ ) from testgen.common.models.hygiene_issue import HygieneIssueType from testgen.common.models.profiling_run import ProfilingRun -from testgen.common.models.scheduler import JobSchedule +from testgen.common.models.scheduler import SCHEDULABLE_JOB_KEYS, JobSchedule from testgen.common.models.table_group import TableGroup from testgen.common.models.test_definition import TestDefinition, TestType from testgen.common.models.test_result import TestResultStatus @@ -391,3 +391,17 @@ def resolve_test_definition(test_definition_id: str) -> TestDefinition: if td is None: raise MCPResourceNotAccessible("Test definition", test_definition_id) return td + + +def resolve_schedule(schedule_id: str) -> JobSchedule: + """Resolve a user-managed schedule ID, collapsing missing-or-inaccessible into one error path.""" + sched_uuid = parse_uuid(schedule_id, "schedule_id") + perms = get_project_permissions() + sched = JobSchedule.get( + JobSchedule.id == sched_uuid, + JobSchedule.key.in_(SCHEDULABLE_JOB_KEYS), + JobSchedule.project_code.in_(perms.allowed_codes), + ) + if sched is None: + raise MCPResourceNotAccessible("Schedule", schedule_id) + return sched diff --git a/testgen/mcp/tools/profiling.py b/testgen/mcp/tools/profiling.py index d8ac35a2..79ed638c 100644 --- a/testgen/mcp/tools/profiling.py +++ b/testgen/mcp/tools/profiling.py @@ -562,6 +562,8 @@ def _render_pending_profiling_je(doc: MdDoc, je: JobExecution, label: str) -> No status_label = ProfilingRunSummary.STATUS_LABEL.get(je.status, je.status) doc.heading(3, f"{label} — {status_label}") doc.field("Job ID", je.id, code=True) + if je.job_schedule_id is not None: + doc.field("Schedule", je.job_schedule_id, code=True) doc.field("Submitted", je.created_at) doc.field("Started", je.started_at or "—") doc.field("Ended", je.completed_at or "In progress") @@ -571,6 +573,8 @@ def _render_profiling_run_section(doc: MdDoc, run: ProfilingRunSummary) -> None: title = run.table_groups_name or run.profiling_run_id or run.job_execution_id doc.heading(2, f"{title} — {run.status_label}") doc.field("Job ID", run.job_execution_id, code=True) + if run.job_schedule_id is not None: + doc.field("Schedule", run.job_schedule_id, code=True) doc.field("Submitted", run.created_at) doc.field("Started", run.started_at or "—") doc.field("Ended", run.completed_at or "In progress") diff --git a/testgen/mcp/tools/schedules.py b/testgen/mcp/tools/schedules.py new file mode 100644 index 00000000..73975aac --- /dev/null +++ b/testgen/mcp/tools/schedules.py @@ -0,0 +1,441 @@ +"""MCP tools for managing recurring TestGen schedules — profiling and test-run schedules.""" + +from datetime import datetime +from enum import StrEnum + +from sqlalchemy import select + +from testgen.common.cron_service import describe_cron, get_cron_sample +from testgen.common.enums import JobKey +from testgen.common.models import get_current_session, with_database_session +from testgen.common.models.job_execution import JobExecution +from testgen.common.models.scheduler import JobSchedule +from testgen.common.models.table_group import TableGroup +from testgen.common.models.test_run import TestRunSummary # STATUS_LABEL is shared with ProfilingRunSummary +from testgen.common.models.test_suite import TestSuite +from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError +from testgen.mcp.permissions import get_project_permissions, mcp_permission +from testgen.mcp.tools.common import ( + DocGroup, + format_page_footer, + format_page_info, + format_run_duration, + resolve_schedule, + resolve_table_group, + resolve_test_suite, + validate_limit, + validate_page, +) +from testgen.mcp.tools.markdown import MdDoc + +_DOC_GROUP = DocGroup.TRIGGER + + +class ScheduleType(StrEnum): + profiling = "profiling" + test_execution = "test_execution" + + +_SCHEDULE_TYPE_TO_JOB_KEY: dict[ScheduleType, JobKey] = { + ScheduleType.profiling: JobKey.run_profile, + ScheduleType.test_execution: JobKey.run_tests, +} + + +def _kind_display(key: str) -> str: + """User-facing label for a schedule's job kind.""" + if key == JobKey.run_profile: + return "Profiling Run" + return "Test Run" + +# --------------------------------------------------------------------------- +# Validation + rendering helpers +# --------------------------------------------------------------------------- + + +def _validate_cron(cron_expression: str, cron_tz: str) -> str: + """Validate cron expression + timezone. Returns the human-readable description.""" + if not cron_expression: + raise MCPUserError("`cron_expression` is required.") + if not cron_tz: + raise MCPUserError("`cron_tz` is required (IANA name, e.g. `UTC`).") + sample = get_cron_sample(cron_expression, cron_tz, sample_count=1) + if "error" in sample: + raise MCPUserError(f"Invalid cron expression or timezone: {sample['error']}") + return sample["readable_expr"] + + +def _parse_schedule_type(value: str) -> ScheduleType: + try: + return ScheduleType(value) + except ValueError as err: + valid = ", ".join(t.value for t in ScheduleType) + raise MCPUserError(f"Invalid schedule_type `{value}`. Valid values: {valid}") from err + + +def _linked_kind_label(key: str) -> str: + """Field label for the linked entity row, based on the schedule's ``key``.""" + if key == JobKey.run_profile: + return "Table Group" + return "Test Suite" + + +def _linked_entity_id(sched: JobSchedule) -> str | None: + """Extract the linked entity UUID from ``kwargs``. ``None`` if the row is malformed.""" + if sched.key == JobKey.run_profile: + return sched.kwargs.get("table_group_id") + return sched.kwargs.get("test_suite_id") + + +def _format_linked(sched: JobSchedule, name: str | None) -> str: + """Combined ``: `name` (ID: `uuid`)`` line used by both detail block and list rows.""" + linked_id = _linked_entity_id(sched) + name_part = f"`{name}`" if name else "—" + id_part = f" (ID: `{linked_id}`)" if linked_id else "" + return f"{name_part}{id_part}" + + +def _next_run(sched: JobSchedule) -> datetime | None: + try: + return sched.get_sample_triggering_timestamps(1)[0] + except Exception: + return None + + +def _render_schedule( + doc: MdDoc, + sched: JobSchedule, + *, + linked_name: str | None, + include_next_runs: int = 1, +) -> None: + doc.field("Schedule ID", sched.id, code=True) + doc.field("Type", _kind_display(sched.key)) + doc.field(_linked_kind_label(sched.key), _format_linked(sched, linked_name)) + doc.field("Cron expression", sched.cron_expr, code=True) + if (readable := describe_cron(sched.cron_expr)) is not None: + doc.field("Cron description", readable) + doc.field("Timezone", sched.cron_tz) + doc.field("Status", "Active" if sched.active else "Paused") + if include_next_runs > 0: + try: + next_times = sched.get_sample_triggering_timestamps(include_next_runs) + except Exception: + next_times = [] + if next_times: + label = "Next run" if include_next_runs == 1 else "Next runs" + doc.field(label, ", ".join(_format_dt(t) for t in next_times)) + + +def _format_dt(value: datetime | None) -> str: + if value is None: + return "—" + return value.strftime("%Y-%m-%d %H:%M %Z") or value.strftime("%Y-%m-%d %H:%M") + + +def _resolve_linked_names(schedules: list[JobSchedule]) -> dict[tuple[str, str], str]: + """Batch-fetch linked-entity names for a list of schedules. Avoids N+1. + + Returns a dict keyed by (kind, id) where kind ∈ {'tg', 'suite'} and id is the UUID string. + """ + session = get_current_session() + tg_ids: set[str] = set() + suite_ids: set[str] = set() + for sched in schedules: + linked_id = _linked_entity_id(sched) + if linked_id is None: + continue + if sched.key == JobKey.run_profile: + tg_ids.add(linked_id) + else: + suite_ids.add(linked_id) + + names: dict[tuple[str, str], str] = {} + if tg_ids: + rows = session.execute( + select(TableGroup.id, TableGroup.table_groups_name).where(TableGroup.id.in_(tg_ids)) + ).all() + for row_id, row_name in rows: + names[("tg", str(row_id))] = row_name + if suite_ids: + rows = session.execute( + select(TestSuite.id, TestSuite.test_suite).where(TestSuite.id.in_(suite_ids)) + ).all() + for row_id, row_name in rows: + names[("suite", str(row_id))] = row_name + return names + + +def _linked_name(sched: JobSchedule, names: dict[tuple[str, str], str]) -> str | None: + linked_id = _linked_entity_id(sched) + if linked_id is None: + return None + kind = "tg" if sched.key == JobKey.run_profile else "suite" + return names.get((kind, linked_id)) + + +# --------------------------------------------------------------------------- +# Write tools +# --------------------------------------------------------------------------- + + +@with_database_session +@mcp_permission("edit") +def create_profiling_schedule( + table_group_id: str, + cron_expression: str, + cron_tz: str = "UTC", + active: bool = True, +) -> str: + """Create a recurring profiling schedule for a table group. + + Args: + table_group_id: UUID of the table group to profile, e.g. from ``get_data_inventory``. + cron_expression: Five-field cron expression, e.g. ``0 3 * * *`` for daily at 03:00. + cron_tz: IANA timezone name (e.g. ``America/New_York``). Defaults to ``UTC``. + active: Whether the schedule should start active. Defaults to ``True``. + """ + table_group = resolve_table_group(table_group_id) + _validate_cron(cron_expression, cron_tz) + sched = JobSchedule( + project_code=table_group.project_code, + key=JobKey.run_profile, + kwargs={"table_group_id": str(table_group.id)}, + cron_expr=cron_expression, + cron_tz=cron_tz, + active=active, + ) + sched.save() + get_current_session().flush() + + doc = MdDoc() + doc.heading(1, f"Profiling schedule created for `{table_group.table_groups_name}`") + _render_schedule(doc, sched, linked_name=table_group.table_groups_name) + return doc.render() + + +@with_database_session +@mcp_permission("edit") +def create_test_run_schedule( + test_suite_id: str, + cron_expression: str, + cron_tz: str = "UTC", + active: bool = True, +) -> str: + """Create a recurring test-run schedule for a test suite. + + Args: + test_suite_id: UUID of the test suite to run, e.g. from ``list_test_suites``. + cron_expression: Five-field cron expression, e.g. ``0 6 * * 1`` for Mondays at 06:00. + cron_tz: IANA timezone name (e.g. ``America/New_York``). Defaults to ``UTC``. + active: Whether the schedule should start active. Defaults to ``True``. + """ + suite = resolve_test_suite(test_suite_id) + _validate_cron(cron_expression, cron_tz) + sched = JobSchedule( + project_code=suite.project_code, + key=JobKey.run_tests, + kwargs={"test_suite_id": str(suite.id)}, + cron_expr=cron_expression, + cron_tz=cron_tz, + active=active, + ) + sched.save() + get_current_session().flush() + + doc = MdDoc() + doc.heading(1, f"Test-run schedule created for `{suite.test_suite}`") + _render_schedule(doc, sched, linked_name=suite.test_suite) + return doc.render() + + +@with_database_session +@mcp_permission("edit") +def update_schedule( + schedule_id: str, + cron_expression: str | None = None, + cron_tz: str | None = None, + active: bool | None = None, +) -> str: + """Update a schedule's cron, timezone, or active state. Atomic — no partial save. + + The job type and linked configuration are immutable — delete and recreate to change them. + + Args: + schedule_id: UUID of the schedule, e.g. from ``list_schedules``. + cron_expression: New cron expression. Omit to leave unchanged. + cron_tz: New IANA timezone. Omit to leave unchanged. + active: ``True`` to resume, ``False`` to pause. Omit to leave unchanged. + """ + if cron_expression is None and cron_tz is None and active is None: + raise MCPUserError("No fields supplied to update.") + + sched = resolve_schedule(schedule_id) + + new_expr = cron_expression if cron_expression is not None else sched.cron_expr + new_tz = cron_tz if cron_tz is not None else sched.cron_tz + if cron_expression is not None or cron_tz is not None: + _validate_cron(new_expr, new_tz) + + changes: list[tuple[str, object, object]] = [] + if cron_expression is not None and cron_expression != sched.cron_expr: + changes.append(("Cron expression", sched.cron_expr, cron_expression)) + sched.cron_expr = cron_expression + if cron_tz is not None and cron_tz != sched.cron_tz: + changes.append(("Timezone", sched.cron_tz, cron_tz)) + sched.cron_tz = cron_tz + if active is not None and active != sched.active: + before = "Active" if sched.active else "Paused" + after = "Active" if active else "Paused" + changes.append(("Status", before, after)) + sched.active = active + + sched.save() + get_current_session().flush() + + doc = MdDoc() + doc.heading(1, "Schedule updated") + doc.field("Schedule ID", sched.id, code=True) + if not changes: + doc.text("No fields changed — supplied values matched the current state.") + return doc.render() + doc.table(["Field", "Before", "After"], [list(c) for c in changes]) + return doc.render() + + +@with_database_session +@mcp_permission("edit") +def delete_schedule(schedule_id: str) -> str: + """Delete a schedule. Past executions remain accessible via ``list_test_runs`` / ``list_profiling_runs``. + + Args: + schedule_id: UUID of the schedule, e.g. from ``list_schedules``. + """ + sched = resolve_schedule(schedule_id) + JobSchedule.delete(sched.id) + + doc = MdDoc() + doc.heading(1, "Schedule deleted") + doc.field("Schedule ID", sched.id, code=True) + return doc.render() + + +# --------------------------------------------------------------------------- +# Read tools +# --------------------------------------------------------------------------- + + +@with_database_session +@mcp_permission("view") +def list_schedules( + project_code: str, + schedule_type: str | None = None, + limit: int = 20, + page: int = 1, +) -> str: + """List schedules for a project — profiling and test-run schedules. + + Args: + project_code: Project to scope to, e.g. from ``list_projects``. + schedule_type: Optional filter — ``profiling`` or ``test_execution``. + limit: Max rows per page. Defaults to 20. + page: 1-indexed page number. Defaults to 1. + """ + validate_page(page) + validate_limit(limit, 100) + + perms = get_project_permissions() + if project_code not in perms.allowed_codes: + raise MCPResourceNotAccessible("Project", project_code) + + key_filter: list[JobKey] | None = None + if schedule_type is not None: + st_enum = _parse_schedule_type(schedule_type) + key_filter = [_SCHEDULE_TYPE_TO_JOB_KEY[st_enum]] + + schedules, total = JobSchedule.list_for_project( + project_code, + key_filter=key_filter, + page=page, + limit=limit, + ) + + doc = MdDoc() + doc.heading(1, f"Schedules — `{project_code}`") + info = format_page_info(total, page, limit) + if info: + doc.text(info) + if not schedules: + doc.text("_No schedules._") + return doc.render() + + linked_names = _resolve_linked_names(schedules) + rows: list[list[object]] = [] + for sched in schedules: + rows.append([ + sched.id, + _kind_display(sched.key), + f"{_linked_kind_label(sched.key)}: {_format_linked(sched, _linked_name(sched, linked_names))}", + sched.cron_expr, + sched.cron_tz, + "Active" if sched.active else "Paused", + _format_dt(_next_run(sched)), + ]) + doc.table( + ["Schedule ID", "Type", "Details", "Cron", "Timezone", "Status", "Next run"], + rows, + code=[0, 3], + ) + footer = format_page_footer(total, page, limit) + if footer: + doc.text(footer) + return doc.render() + + +@with_database_session +@mcp_permission("view") +def get_schedule(schedule_id: str) -> str: + """Get full details for a schedule, including the last five execution attempts. + + Args: + schedule_id: UUID of the schedule, e.g. from ``list_schedules``. + """ + sched = resolve_schedule(schedule_id) + linked_names = _resolve_linked_names([sched]) + linked_name = _linked_name(sched, linked_names) + + doc = MdDoc() + doc.heading(1, "Schedule") + _render_schedule(doc, sched, linked_name=linked_name, include_next_runs=3) + + history = get_current_session().scalars( + select(JobExecution) + .where(JobExecution.job_schedule_id == sched.id) + .order_by(JobExecution.created_at.desc()) + .limit(5) + ).all() + + doc.heading(2, "Recent executions") + if not history: + doc.text("_No executions yet._") + return doc.render() + + rows: list[list[object]] = [] + for je in history: + rows.append([ + je.id, + TestRunSummary.STATUS_LABEL.get(je.status, je.status), + je.started_at, + je.completed_at, + format_run_duration(je.started_at, je.completed_at), + ]) + doc.table( + ["Job execution ID", "Status", "Started", "Completed", "Duration"], + rows, + code=[0], + ) + doc.text( + "_Showing the 5 most recent executions._ " + "Use `list_test_runs` or `list_profiling_runs` for full history." + ) + return doc.render() diff --git a/testgen/mcp/tools/test_runs.py b/testgen/mcp/tools/test_runs.py index 7ca60dcc..415571e6 100644 --- a/testgen/mcp/tools/test_runs.py +++ b/testgen/mcp/tools/test_runs.py @@ -279,6 +279,8 @@ def _render_pending_je(doc: MdDoc, je: JobExecution, label: str) -> None: status_label = TestRunSummary.STATUS_LABEL.get(je.status, je.status) doc.heading(3, f"{label} — {status_label}") doc.field("Job ID", je.id, code=True) + if je.job_schedule_id is not None: + doc.field("Schedule", je.job_schedule_id, code=True) doc.field("Submitted", je.created_at) doc.field("Started", je.started_at or "—") doc.field("Ended", je.completed_at or "In progress") @@ -288,6 +290,8 @@ def _render_test_run_section(doc: MdDoc, run: TestRunSummary) -> None: title = run.test_suite or run.project_code doc.heading(2, f"{title} — {run.status_label}") doc.field("Job ID", run.job_execution_id, code=True) + if run.job_schedule_id is not None: + doc.field("Schedule", run.job_schedule_id, code=True) if run.test_suite: doc.field("Test suite", run.test_suite) if run.table_groups_name: diff --git a/testgen/scheduler/cli_scheduler.py b/testgen/scheduler/cli_scheduler.py index d5da99fe..cfa18657 100644 --- a/testgen/scheduler/cli_scheduler.py +++ b/testgen/scheduler/cli_scheduler.py @@ -48,7 +48,7 @@ def get_jobs(self) -> Iterable[CliJob]: self.reload_timer.start() jobs = {} - for job_model in JobSchedule.select_where(): + for job_model in JobSchedule.select_runnable(): if job_model.key not in JOB_DISPATCH: LOG.error("Job '%s' scheduled but not registered", job_model.key) continue diff --git a/testgen/ui/utils.py b/testgen/ui/utils.py index fe097761..270de953 100644 --- a/testgen/ui/utils.py +++ b/testgen/ui/utils.py @@ -1,20 +1,10 @@ -import zoneinfo from collections.abc import Callable -from datetime import datetime from typing import TypedDict -import cron_converter -import cron_descriptor - +from testgen.common.cron_service import get_cron_sample from testgen.ui.session import temp_value -class CronSample(TypedDict): - id: str | None - error: str | None - samples: list[str] | list[int] | None - readable_expr: str | None - class CronSampleHandlerPayload(TypedDict): tz: str cron_expr: str @@ -23,33 +13,6 @@ class CronSampleHandlerPayload(TypedDict): CronSampleCallback = Callable[[CronSampleHandlerPayload], None] -def get_cron_sample( - cron_expr: str, - cron_tz: str, - sample_count: int, - *, - reference_time: datetime | None = None, - formatted: bool = False, -) -> CronSample: - try: - cron_obj = cron_converter.Cron(cron_expr) - cron_schedule = cron_obj.schedule(reference_time or datetime.now(zoneinfo.ZoneInfo(cron_tz))) - readble_cron_schedule = cron_descriptor.get_description(cron_expr) - if formatted: - samples = [cron_schedule.next().strftime("%a %b %-d, %-I:%M %p") for _ in range(sample_count)] - else: - samples = [int(cron_schedule.next().timestamp()) for _ in range(sample_count)] - except ValueError as e: - return {"error": str(e)} - except Exception as e: - return {"error": "Error validating the Cron expression"} - else: - return { - "samples": samples, - "readable_expr": readble_cron_schedule, - } - - def get_cron_sample_handler(key: str, *, sample_count: int = 3) -> tuple[dict | None, CronSampleCallback]: cron_sample_result, set_cron_sample = temp_value(key, default={}) diff --git a/testgen/ui/views/dialogs/manage_schedules.py b/testgen/ui/views/dialogs/manage_schedules.py index 5aeeff5d..4d917f7c 100644 --- a/testgen/ui/views/dialogs/manage_schedules.py +++ b/testgen/ui/views/dialogs/manage_schedules.py @@ -86,7 +86,7 @@ def on_resume(self, item: dict) -> None: st.session_state.pop(RESULT_KEY, None) def on_cron_sample(self, payload: dict) -> None: - from testgen.ui.utils import get_cron_sample + from testgen.common.cron_service import get_cron_sample sample = get_cron_sample(payload["cron_expr"], payload["tz"], CRON_SAMPLE_COUNT, formatted=True) st.session_state[CRON_SAMPLE_KEY] = sample diff --git a/testgen/ui/views/monitors_dashboard.py b/testgen/ui/views/monitors_dashboard.py index e48f43e9..b7eaaad4 100644 --- a/testgen/ui/views/monitors_dashboard.py +++ b/testgen/ui/views/monitors_dashboard.py @@ -7,6 +7,7 @@ import streamlit as st from testgen.commands.test_generation import run_monitor_generation +from testgen.common.cron_service import get_cron_sample from testgen.common.freshness_service import add_business_minutes, get_schedule_params, resolve_holiday_dates from testgen.common.models import get_current_session, with_database_session from testgen.common.models.notification_settings import ( @@ -27,7 +28,7 @@ from testgen.ui.services.query_cache import get_project_summary, get_test_type_summaries from testgen.ui.services.rerun_service import safe_rerun from testgen.ui.session import session, temp_value -from testgen.ui.utils import dict_from_kv, get_cron_sample, get_cron_sample_handler +from testgen.ui.utils import dict_from_kv, get_cron_sample_handler from testgen.ui.views.dialogs.manage_notifications import NotificationSettingsDialogBase from testgen.utils import make_json_safe diff --git a/tests/unit/mcp/test_tools_schedules.py b/tests/unit/mcp/test_tools_schedules.py new file mode 100644 index 00000000..98d17d55 --- /dev/null +++ b/tests/unit/mcp/test_tools_schedules.py @@ -0,0 +1,358 @@ +from datetime import datetime +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest + +from testgen.common.enums import JobKey +from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError + + +def _make_table_group(project_code="demo", name="orders_tg"): + tg = MagicMock() + tg.id = uuid4() + tg.project_code = project_code + tg.table_groups_name = name + return tg + + +def _make_suite(project_code="demo", name="suite_a", is_monitor=False): + suite = MagicMock() + suite.id = uuid4() + suite.project_code = project_code + suite.test_suite = name + suite.is_monitor = is_monitor + return suite + + +def _make_sched(*, key=None, active=True, project_code="demo", linked_id=None): + sched = MagicMock() + sched.id = uuid4() + sched.project_code = project_code + sched.key = key or JobKey.run_profile.value + sched.cron_expr = "0 3 * * *" + sched.cron_tz = "UTC" + sched.active = active + if sched.key == JobKey.run_profile.value: + sched.kwargs = {"table_group_id": linked_id or str(uuid4())} + else: + sched.kwargs = {"test_suite_id": linked_id or str(uuid4())} + sched.get_sample_triggering_timestamps.return_value = [datetime(2026, 5, 19, 3, 0)] + return sched + + +# --------------------------------------------------------------------------- +# create_profiling_schedule +# --------------------------------------------------------------------------- + + +@patch("testgen.mcp.tools.schedules.JobSchedule") +@patch("testgen.mcp.tools.schedules.resolve_table_group") +def test_create_profiling_schedule_happy_path(mock_resolve_tg, mock_sched_cls, db_session_mock): + tg = _make_table_group() + mock_resolve_tg.return_value = tg + saved = _make_sched(linked_id=str(tg.id)) + mock_sched_cls.return_value = saved + + from testgen.mcp.tools.schedules import create_profiling_schedule + + result = create_profiling_schedule( + table_group_id=str(tg.id), + cron_expression="0 3 * * *", + cron_tz="UTC", + ) + + assert "Profiling schedule created" in result + assert "orders_tg" in result + assert "`0 3 * * *`" in result + saved.save.assert_called_once() + + +@patch("testgen.mcp.tools.schedules.resolve_table_group") +def test_create_profiling_schedule_invalid_cron(mock_resolve_tg, db_session_mock): + mock_resolve_tg.return_value = _make_table_group() + + from testgen.mcp.tools.schedules import create_profiling_schedule + + with pytest.raises(MCPUserError) as exc: + create_profiling_schedule( + table_group_id=str(uuid4()), + cron_expression="not a cron", + cron_tz="UTC", + ) + assert "Invalid cron" in str(exc.value) + + +@patch("testgen.mcp.tools.schedules.resolve_table_group") +def test_create_profiling_schedule_invalid_timezone(mock_resolve_tg, db_session_mock): + mock_resolve_tg.return_value = _make_table_group() + + from testgen.mcp.tools.schedules import create_profiling_schedule + + with pytest.raises(MCPUserError) as exc: + create_profiling_schedule( + table_group_id=str(uuid4()), + cron_expression="0 3 * * *", + cron_tz="Not/A_Real_Timezone", + ) + assert "Invalid cron" in str(exc.value) + + +@patch("testgen.mcp.tools.schedules.resolve_table_group") +def test_create_profiling_schedule_empty_cron_rejected(mock_resolve_tg, db_session_mock): + mock_resolve_tg.return_value = _make_table_group() + + from testgen.mcp.tools.schedules import create_profiling_schedule + + with pytest.raises(MCPUserError) as exc: + create_profiling_schedule(table_group_id=str(uuid4()), cron_expression="") + assert "cron_expression" in str(exc.value) + + +@patch("testgen.mcp.tools.schedules.resolve_table_group") +def test_create_profiling_schedule_empty_tz_rejected(mock_resolve_tg, db_session_mock): + mock_resolve_tg.return_value = _make_table_group() + + from testgen.mcp.tools.schedules import create_profiling_schedule + + with pytest.raises(MCPUserError) as exc: + create_profiling_schedule( + table_group_id=str(uuid4()), cron_expression="0 3 * * *", cron_tz="" + ) + assert "cron_tz" in str(exc.value) + + +# --------------------------------------------------------------------------- +# create_test_run_schedule +# --------------------------------------------------------------------------- + + +@patch("testgen.mcp.tools.schedules.JobSchedule") +@patch("testgen.mcp.tools.schedules.resolve_test_suite") +def test_create_test_run_schedule_happy_path(mock_resolve_suite, mock_sched_cls, db_session_mock): + suite = _make_suite() + mock_resolve_suite.return_value = suite + saved = _make_sched(key=JobKey.run_tests.value, linked_id=str(suite.id)) + mock_sched_cls.return_value = saved + + from testgen.mcp.tools.schedules import create_test_run_schedule + + result = create_test_run_schedule( + test_suite_id=str(suite.id), + cron_expression="0 6 * * 1", + cron_tz="UTC", + ) + + assert "Test-run schedule created" in result + assert "suite_a" in result + saved.save.assert_called_once() + + +@patch("testgen.mcp.tools.schedules.resolve_test_suite") +def test_create_test_run_schedule_monitor_suite_rejected(mock_resolve_suite, db_session_mock): + mock_resolve_suite.side_effect = MCPResourceNotAccessible("Test suite", "abc") + + from testgen.mcp.tools.schedules import create_test_run_schedule + + with pytest.raises(MCPResourceNotAccessible): + create_test_run_schedule( + test_suite_id=str(uuid4()), + cron_expression="0 6 * * 1", + ) + + +# --------------------------------------------------------------------------- +# list_schedules +# --------------------------------------------------------------------------- + + +@patch("testgen.mcp.tools.schedules._resolve_linked_names") +@patch("testgen.mcp.tools.schedules.JobSchedule") +def test_list_schedules_basic(mock_sched_cls, mock_linked, db_session_mock): + sched_a = _make_sched(key=JobKey.run_profile.value) + sched_b = _make_sched(key=JobKey.run_tests.value) + mock_sched_cls.list_for_project.return_value = ([sched_a, sched_b], 2) + mock_linked.return_value = { + ("tg", sched_a.kwargs["table_group_id"]): "orders_tg", + ("suite", sched_b.kwargs["test_suite_id"]): "suite_a", + } + + from testgen.mcp.tools.schedules import list_schedules + + result = list_schedules(project_code="demo") + + assert "Schedules" in result + assert "Profiling Run" in result + assert "Test Run" in result + assert "orders_tg" in result + assert "suite_a" in result + + +@patch("testgen.mcp.tools.schedules.JobSchedule") +def test_list_schedules_empty(mock_sched_cls, db_session_mock): + mock_sched_cls.list_for_project.return_value = ([], 0) + + from testgen.mcp.tools.schedules import list_schedules + + result = list_schedules(project_code="demo") + assert "No schedules" in result + + +@patch("testgen.mcp.tools.schedules._resolve_linked_names") +@patch("testgen.mcp.tools.schedules.JobSchedule") +def test_list_schedules_type_filter_maps_to_job_key(mock_sched_cls, mock_linked, db_session_mock): + sched = _make_sched(key=JobKey.run_profile.value) + mock_sched_cls.list_for_project.return_value = ([sched], 1) + mock_linked.return_value = {} + + from testgen.mcp.tools.schedules import list_schedules + + list_schedules(project_code="demo", schedule_type="profiling") + + call_kwargs = mock_sched_cls.list_for_project.call_args + assert call_kwargs.kwargs["key_filter"] == [JobKey.run_profile.value] + + +def test_list_schedules_invalid_schedule_type(db_session_mock): + from testgen.mcp.tools.schedules import list_schedules + + with pytest.raises(MCPUserError) as exc: + list_schedules(project_code="demo", schedule_type="not-a-type") + assert "Invalid schedule_type" in str(exc.value) + + +def test_list_schedules_project_not_accessible(db_session_mock): + from testgen.mcp.tools.schedules import list_schedules + + with pytest.raises(MCPResourceNotAccessible): + list_schedules(project_code="other_project") + + +# --------------------------------------------------------------------------- +# get_schedule +# --------------------------------------------------------------------------- + + +@patch("testgen.mcp.tools.schedules.get_current_session") +@patch("testgen.mcp.tools.schedules._resolve_linked_names") +@patch("testgen.mcp.tools.schedules.resolve_schedule") +def test_get_schedule_no_executions(mock_resolve, mock_linked, mock_session, db_session_mock): + sched = _make_sched(key=JobKey.run_profile.value) + mock_resolve.return_value = sched + mock_linked.return_value = {("tg", sched.kwargs["table_group_id"]): "orders_tg"} + session = MagicMock() + session.scalars.return_value.all.return_value = [] + mock_session.return_value = session + + from testgen.mcp.tools.schedules import get_schedule + + result = get_schedule(schedule_id=str(sched.id)) + assert "orders_tg" in result + assert "No executions yet" in result + + +@patch("testgen.mcp.tools.schedules.get_current_session") +@patch("testgen.mcp.tools.schedules._resolve_linked_names") +@patch("testgen.mcp.tools.schedules.resolve_schedule") +def test_get_schedule_with_executions(mock_resolve, mock_linked, mock_session, db_session_mock): + sched = _make_sched(key=JobKey.run_profile.value) + mock_resolve.return_value = sched + mock_linked.return_value = {("tg", sched.kwargs["table_group_id"]): "orders_tg"} + + je = MagicMock() + je.id = uuid4() + je.status = "Completed" + je.created_at = datetime(2026, 5, 18, 3, 0) + je.started_at = datetime(2026, 5, 18, 3, 0) + je.completed_at = datetime(2026, 5, 18, 3, 12) + session = MagicMock() + session.scalars.return_value.all.return_value = [je] + mock_session.return_value = session + + from testgen.mcp.tools.schedules import get_schedule + + result = get_schedule(schedule_id=str(sched.id)) + assert "Recent executions" in result + assert str(je.id) in result + + +# --------------------------------------------------------------------------- +# update_schedule +# --------------------------------------------------------------------------- + + +@patch("testgen.mcp.tools.schedules.resolve_schedule") +def test_update_schedule_happy_path_diff(mock_resolve, db_session_mock): + sched = _make_sched(key=JobKey.run_profile.value, active=True) + mock_resolve.return_value = sched + + from testgen.mcp.tools.schedules import update_schedule + + result = update_schedule(schedule_id=str(sched.id), active=False) + + assert "Schedule updated" in result + assert "Active" in result and "Paused" in result + sched.save.assert_called_once() + + +def test_update_schedule_empty_payload_rejected(db_session_mock): + from testgen.mcp.tools.schedules import update_schedule + + with pytest.raises(MCPUserError) as exc: + update_schedule(schedule_id=str(uuid4())) + assert "No fields supplied" in str(exc.value) + + +@patch("testgen.mcp.tools.schedules.resolve_schedule") +def test_update_schedule_invalid_cron(mock_resolve, db_session_mock): + sched = _make_sched(key=JobKey.run_profile.value) + mock_resolve.return_value = sched + + from testgen.mcp.tools.schedules import update_schedule + + with pytest.raises(MCPUserError) as exc: + update_schedule(schedule_id=str(sched.id), cron_expression="garbage") + assert "Invalid cron" in str(exc.value) + sched.save.assert_not_called() + + +@patch("testgen.mcp.tools.schedules.resolve_schedule") +def test_update_schedule_monitor_schedule_not_accessible(mock_resolve, db_session_mock): + """resolve_schedule filters out monitor schedules — caller sees the unified not-accessible error.""" + mock_resolve.side_effect = MCPResourceNotAccessible("Schedule", "abc") + + from testgen.mcp.tools.schedules import update_schedule + + with pytest.raises(MCPResourceNotAccessible): + update_schedule(schedule_id=str(uuid4()), active=False) + + +# --------------------------------------------------------------------------- +# delete_schedule +# --------------------------------------------------------------------------- + + +@patch("testgen.mcp.tools.schedules.JobSchedule") +@patch("testgen.mcp.tools.schedules.resolve_schedule") +def test_delete_schedule_happy_path(mock_resolve, mock_sched_cls, db_session_mock): + sched = _make_sched(key=JobKey.run_profile.value) + mock_resolve.return_value = sched + + from testgen.mcp.tools.schedules import delete_schedule + + result = delete_schedule(schedule_id=str(sched.id)) + assert "Schedule deleted" in result + mock_sched_cls.delete.assert_called_once_with(sched.id) + + +@patch("testgen.mcp.tools.schedules.JobSchedule") +@patch("testgen.mcp.tools.schedules.resolve_schedule") +def test_delete_schedule_monitor_schedule_not_accessible(mock_resolve, mock_sched_cls, db_session_mock): + """resolve_schedule filters out monitor schedules — caller sees the unified not-accessible error.""" + mock_resolve.side_effect = MCPResourceNotAccessible("Schedule", "abc") + + from testgen.mcp.tools.schedules import delete_schedule + + with pytest.raises(MCPResourceNotAccessible): + delete_schedule(schedule_id=str(uuid4())) + mock_sched_cls.delete.assert_not_called() diff --git a/tests/unit/scheduler/test_scheduler_cli.py b/tests/unit/scheduler/test_scheduler_cli.py index d4008250..8dfa045d 100644 --- a/tests/unit/scheduler/test_scheduler_cli.py +++ b/tests/unit/scheduler/test_scheduler_cli.py @@ -42,7 +42,7 @@ def popen_mock(popen_proc_mock): @pytest.fixture def db_jobs(scheduler_instance): with ( - patch("testgen.scheduler.cli_scheduler.JobSchedule.select_where") as mock, + patch("testgen.scheduler.cli_scheduler.JobSchedule.select_runnable") as mock, ): yield mock From 55d6a796eb3a619ea4e519803b1fa6471b6940bd Mon Sep 17 00:00:00 2001 From: Aarthy Adityan Date: Tue, 19 May 2026 21:13:40 -0400 Subject: [PATCH 28/58] fix(salesforce): apply MR review feedback - Dialect: drop dead get_pool_class duplicate of flavor_service.get_engine_args. - Min_Date measure: explicit CAST({COLUMN_NAME}/BASELINE_VALUE AS DATE) to match the lookup query and Future_Date pattern. - LOV_All lookup: inner ORDER BY 1 LIMIT 1000 so output sort matches the measure. - Connection form: drop unreachable authMethod.val conditions inside requiredIf validators (the fields are already gated by an outer branch). Prune stale validityPerField entries when the auth-method radio switches, so the Next button enables after filling Client Credentials fields. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../database/salesforce_data360_dialect.py | 4 ---- .../test_types_LOV_All.yaml | 2 +- .../test_types_Min_Date.yaml | 2 +- .../static/js/components/connection_form.js | 20 ++++++++++++------- 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/testgen/common/database/salesforce_data360_dialect.py b/testgen/common/database/salesforce_data360_dialect.py index 43b44459..271f09f3 100644 --- a/testgen/common/database/salesforce_data360_dialect.py +++ b/testgen/common/database/salesforce_data360_dialect.py @@ -25,7 +25,6 @@ ) from salesforcecdpconnector.exceptions import Error as _CdpError from sqlalchemy.engine.default import DefaultDialect -from sqlalchemy.pool import StaticPool def _format_oauth_failure(grant_label: str, response) -> str: @@ -164,6 +163,3 @@ def initialize(self, connection): # Skip server-version detection and other introspection that # DefaultDialect.initialize() performs. pass - - def get_pool_class(self, _url): - return StaticPool diff --git a/testgen/template/dbsetup_test_types/test_types_LOV_All.yaml b/testgen/template/dbsetup_test_types/test_types_LOV_All.yaml index 59a07d79..cdd4bfda 100644 --- a/testgen/template/dbsetup_test_types/test_types_LOV_All.yaml +++ b/testgen/template/dbsetup_test_types/test_types_LOV_All.yaml @@ -217,6 +217,6 @@ test_types: sql_flavor: salesforce_data360 lookup_type: null lookup_query: |- - SELECT ARRAY_JOIN(ARRAY_AGG(DISTINCT "{COLUMN_NAME}"), '|') AS lov FROM "{TABLE_NAME}" HAVING ARRAY_JOIN(ARRAY_AGG(DISTINCT "{COLUMN_NAME}"), '|') <> {THRESHOLD_VALUE} LIMIT {LIMIT}; + SELECT ARRAY_JOIN(ARRAY_AGG(sub_val), '|') AS lov FROM (SELECT DISTINCT "{COLUMN_NAME}" AS sub_val FROM "{TABLE_NAME}" ORDER BY 1 LIMIT 1000) sub_lov HAVING ARRAY_JOIN(ARRAY_AGG(sub_val), '|') <> {THRESHOLD_VALUE} LIMIT {LIMIT}; error_type: Test Results test_templates: [] diff --git a/testgen/template/dbsetup_test_types/test_types_Min_Date.yaml b/testgen/template/dbsetup_test_types/test_types_Min_Date.yaml index a3f9cbe1..1ade805e 100644 --- a/testgen/template/dbsetup_test_types/test_types_Min_Date.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Min_Date.yaml @@ -121,7 +121,7 @@ test_types: test_type: Min_Date sql_flavor: salesforce_data360 measure: |- - SUM(CASE WHEN {COLUMN_NAME} < '{BASELINE_VALUE}' THEN 1 ELSE 0 END) + SUM(CASE WHEN CAST({COLUMN_NAME} AS DATE) < CAST('{BASELINE_VALUE}' AS DATE) THEN 1 ELSE 0 END) test_operator: '>' test_condition: |- {THRESHOLD_VALUE} diff --git a/testgen/ui/static/js/components/connection_form.js b/testgen/ui/static/js/components/connection_form.js index 0236996f..1fadaacf 100644 --- a/testgen/ui/static/js/components/connection_form.js +++ b/testgen/ui/static/js/components/connection_form.js @@ -1398,7 +1398,16 @@ const SalesforceData360Form = ( { label: 'Client Credentials Flow', value: 'client_credentials' }, ], value: authMethod, - onChange: (value) => authMethod.val = value, + onChange: (value) => { + authMethod.val = value; + if (value === 'jwt') { + delete validityPerField['consumer_secret']; + } else { + delete validityPerField['permitted_user']; + delete validityPerField['private_key']; + } + isValid.val = Object.values(validityPerField).every(v => v); + }, layout: 'inline', }), Input({ @@ -1427,10 +1436,7 @@ const SalesforceData360Form = ( validityPerField['permitted_user'] = state.valid; isValid.val = Object.values(validityPerField).every(v => v); }, - validators: [ - requiredIf(() => authMethod.val === 'jwt'), - maxLength(250), - ], + validators: [required, maxLength(250)], }), FileInput({ name: 'private_key', @@ -1456,7 +1462,7 @@ const SalesforceData360Form = ( isValid.val = Object.values(validityPerField).every(v => v); }, validators: [ - requiredIf(() => authMethod.val === 'client_credentials' && !originalConnection?.connection_id || !originalConnection?.private_key), + requiredIf(() => !originalConnection?.connection_id || !originalConnection?.private_key), sizeLimit(200 * 1024 * 1024), ], }), @@ -1476,7 +1482,7 @@ const SalesforceData360Form = ( isValid.val = Object.values(validityPerField).every(v => v); }, validators: [ - requiredIf(() => authMethod.val === 'jwt' && !originalConnection?.connection_id || !originalConnection?.project_pw_encrypted), + requiredIf(() => !originalConnection?.connection_id || !originalConnection?.project_pw_encrypted), ], }); }, From a4037d8b918db928e32dc98e4d48c54aa4368079 Mon Sep 17 00:00:00 2001 From: Ricardo Boni Date: Wed, 20 May 2026 08:27:28 -0400 Subject: [PATCH 29/58] refactor(mcp): apply TG-1071 review feedback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Drop hyphens: "Test-run schedule" → "Test run schedule" (heading, docstring) - Rename ``ScheduleType`` values to ``profiling_run`` / ``test_run`` to align with the rendered display labels - Recent-executions section: "Recent executions" → "Recent runs", column "Job execution ID" → "Job ID", footer "5 most recent executions" → "5 most recent runs" — matches the "run" terminology used elsewhere - Cached wrapper ``get_monitor_schedule(monitor_suite_id)`` in ``ui/views/monitors_dashboard.py``; restores per-render-rerun caching for the two view-side call sites that previously relied on ``JobSchedule.get``'s removed ``@st.cache_data`` Co-Authored-By: Claude Opus 4.7 (1M context) --- testgen/mcp/tools/schedules.py | 22 +++++++++++----------- testgen/ui/services/query_cache.py | 11 +++++++++++ testgen/ui/views/monitors_dashboard.py | 12 +++--------- tests/unit/mcp/test_tools_schedules.py | 8 ++++---- 4 files changed, 29 insertions(+), 24 deletions(-) diff --git a/testgen/mcp/tools/schedules.py b/testgen/mcp/tools/schedules.py index 73975aac..31b534b1 100644 --- a/testgen/mcp/tools/schedules.py +++ b/testgen/mcp/tools/schedules.py @@ -32,13 +32,13 @@ class ScheduleType(StrEnum): - profiling = "profiling" - test_execution = "test_execution" + profiling_run = "profiling_run" + test_run = "test_run" _SCHEDULE_TYPE_TO_JOB_KEY: dict[ScheduleType, JobKey] = { - ScheduleType.profiling: JobKey.run_profile, - ScheduleType.test_execution: JobKey.run_tests, + ScheduleType.profiling_run: JobKey.run_profile, + ScheduleType.test_run: JobKey.run_tests, } @@ -244,7 +244,7 @@ def create_test_run_schedule( get_current_session().flush() doc = MdDoc() - doc.heading(1, f"Test-run schedule created for `{suite.test_suite}`") + doc.heading(1, f"Test run schedule created for `{suite.test_suite}`") _render_schedule(doc, sched, linked_name=suite.test_suite) return doc.render() @@ -333,11 +333,11 @@ def list_schedules( limit: int = 20, page: int = 1, ) -> str: - """List schedules for a project — profiling and test-run schedules. + """List schedules for a project — profiling and test run schedules. Args: project_code: Project to scope to, e.g. from ``list_projects``. - schedule_type: Optional filter — ``profiling`` or ``test_execution``. + schedule_type: Optional filter — ``profiling_run`` or ``test_run``. limit: Max rows per page. Defaults to 20. page: 1-indexed page number. Defaults to 1. """ @@ -415,9 +415,9 @@ def get_schedule(schedule_id: str) -> str: .limit(5) ).all() - doc.heading(2, "Recent executions") + doc.heading(2, "Recent runs") if not history: - doc.text("_No executions yet._") + doc.text("_No runs yet._") return doc.render() rows: list[list[object]] = [] @@ -430,12 +430,12 @@ def get_schedule(schedule_id: str) -> str: format_run_duration(je.started_at, je.completed_at), ]) doc.table( - ["Job execution ID", "Status", "Started", "Completed", "Duration"], + ["Job ID", "Status", "Started", "Completed", "Duration"], rows, code=[0], ) doc.text( - "_Showing the 5 most recent executions._ " + "_Showing the 5 most recent runs._ " "Use `list_test_runs` or `list_profiling_runs` for full history." ) return doc.render() diff --git a/testgen/ui/services/query_cache.py b/testgen/ui/services/query_cache.py index 7dca4918..c90afe0b 100644 --- a/testgen/ui/services/query_cache.py +++ b/testgen/ui/services/query_cache.py @@ -16,6 +16,7 @@ from testgen.common.models.profiling_run import ProfilingRun, ProfilingRunSummary from testgen.common.models.project import Project, ProjectSummary from testgen.common.models.project_membership import ProjectMembership +from testgen.common.models.scheduler import RUN_MONITORS_JOB_KEY, JobSchedule from testgen.common.models.table_group import TableGroup, TableGroupStats, TableGroupSummary from testgen.common.models.test_definition import TestType, TestTypeSummary from testgen.common.models.test_run import TestRun, TestRunSummary @@ -126,3 +127,13 @@ def get_profiling_run_summaries( page_size: int = 20, ) -> tuple[list[ProfilingRunSummary], int]: return ProfilingRun.select_summary(project_code, table_group_id, page=page, page_size=page_size) + + +# -- JobSchedule -------------------------------------------------------------- + +@st.cache_data(show_spinner=False) +def get_monitor_schedule(monitor_suite_id: str | UUID) -> JobSchedule | None: + return JobSchedule.get( + JobSchedule.key == RUN_MONITORS_JOB_KEY, + JobSchedule.kwargs["test_suite_id"].astext == str(monitor_suite_id), + ) diff --git a/testgen/ui/views/monitors_dashboard.py b/testgen/ui/views/monitors_dashboard.py index b7eaaad4..1b6becb9 100644 --- a/testgen/ui/views/monitors_dashboard.py +++ b/testgen/ui/views/monitors_dashboard.py @@ -25,7 +25,7 @@ from testgen.ui.navigation.router import Router from testgen.ui.queries.profiling_queries import get_tables_by_table_group from testgen.ui.services.database_service import execute_db_query, fetch_all_from_db, fetch_one_from_db -from testgen.ui.services.query_cache import get_project_summary, get_test_type_summaries +from testgen.ui.services.query_cache import get_monitor_schedule, get_project_summary, get_test_type_summaries from testgen.ui.services.rerun_service import safe_rerun from testgen.ui.session import session, temp_value from testgen.ui.utils import dict_from_kv, get_cron_sample_handler @@ -108,10 +108,7 @@ def render( if monitor_suite_id: with st.spinner(text="Loading data ..."): - monitor_schedule = JobSchedule.get( - JobSchedule.key == RUN_MONITORS_JOB_KEY, - JobSchedule.kwargs["test_suite_id"].astext == str(monitor_suite_id), - ) + monitor_schedule = get_monitor_schedule(monitor_suite_id) anomaly_type_filter = [t for t in anomaly_type_filter.split(",") if t in ANOMALY_TYPE_FILTERS] if anomaly_type_filter else None if sort_field and sort_field not in ALLOWED_SORT_FIELDS: @@ -758,10 +755,7 @@ def on_close_trends(_payload=None): predictions = {} if len(definitions) > 0: test_suite = TestSuite.get(table_group.monitor_test_suite_id) - monitor_schedule = JobSchedule.get( - JobSchedule.key == RUN_MONITORS_JOB_KEY, - JobSchedule.kwargs["test_suite_id"].astext == str(table_group.monitor_test_suite_id), - ) + monitor_schedule = get_monitor_schedule(table_group.monitor_test_suite_id) monitor_lookback = test_suite.monitor_lookback predict_sensitivity = test_suite.predict_sensitivity or PredictSensitivity.medium diff --git a/tests/unit/mcp/test_tools_schedules.py b/tests/unit/mcp/test_tools_schedules.py index 98d17d55..4278ca58 100644 --- a/tests/unit/mcp/test_tools_schedules.py +++ b/tests/unit/mcp/test_tools_schedules.py @@ -143,7 +143,7 @@ def test_create_test_run_schedule_happy_path(mock_resolve_suite, mock_sched_cls, cron_tz="UTC", ) - assert "Test-run schedule created" in result + assert "Test run schedule created" in result assert "suite_a" in result saved.save.assert_called_once() @@ -207,7 +207,7 @@ def test_list_schedules_type_filter_maps_to_job_key(mock_sched_cls, mock_linked, from testgen.mcp.tools.schedules import list_schedules - list_schedules(project_code="demo", schedule_type="profiling") + list_schedules(project_code="demo", schedule_type="profiling_run") call_kwargs = mock_sched_cls.list_for_project.call_args assert call_kwargs.kwargs["key_filter"] == [JobKey.run_profile.value] @@ -248,7 +248,7 @@ def test_get_schedule_no_executions(mock_resolve, mock_linked, mock_session, db_ result = get_schedule(schedule_id=str(sched.id)) assert "orders_tg" in result - assert "No executions yet" in result + assert "No runs yet" in result @patch("testgen.mcp.tools.schedules.get_current_session") @@ -272,7 +272,7 @@ def test_get_schedule_with_executions(mock_resolve, mock_linked, mock_session, d from testgen.mcp.tools.schedules import get_schedule result = get_schedule(schedule_id=str(sched.id)) - assert "Recent executions" in result + assert "Recent runs" in result assert str(je.id) in result From 476a7d0072951e7a7aaeeb7e89b9c8c2b4b00cb1 Mon Sep 17 00:00:00 2001 From: testgen-ci-bot Date: Wed, 20 May 2026 18:09:23 +0000 Subject: [PATCH 30/58] ci: bump base image to v16 --- deploy/testgen.dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deploy/testgen.dockerfile b/deploy/testgen.dockerfile index 743f9edf..800cfa4f 100644 --- a/deploy/testgen.dockerfile +++ b/deploy/testgen.dockerfile @@ -1,4 +1,4 @@ -ARG TESTGEN_BASE_LABEL=v15 +ARG TESTGEN_BASE_LABEL=v16 FROM datakitchen/dataops-testgen-base:${TESTGEN_BASE_LABEL} AS release-image From 5f4b3b87193098a493254f85e5b06ae974958a3b Mon Sep 17 00:00:00 2001 From: Astor Date: Wed, 20 May 2026 17:38:49 -0300 Subject: [PATCH 31/58] feat(TG-1001): exclude monitor suites from all queries Add `is_monitor IS NOT TRUE` filter to all test suite queries in the data catalog UI, CLI get-entities commands, and observability exporter so monitor suites are never surfaced to end users. Also fixes a pre-existing GROUP BY bug in get_test_suite_list.sql where the missing GROUP BY caused a PostgreSQL grouping error. Co-Authored-By: Claude Sonnet 4.6 --- testgen/commands/run_observability_exporter.py | 1 + testgen/template/get_entities/get_test_generation_list.sql | 1 + testgen/template/get_entities/get_test_info.sql | 1 + testgen/template/get_entities/get_test_run_list.sql | 1 + testgen/template/get_entities/get_test_suite.sql | 3 ++- testgen/template/get_entities/get_test_suite_list.sql | 2 ++ testgen/ui/views/data_catalog.py | 4 ++-- 7 files changed, 10 insertions(+), 3 deletions(-) diff --git a/testgen/commands/run_observability_exporter.py b/testgen/commands/run_observability_exporter.py index 71179e9d..e0339026 100644 --- a/testgen/commands/run_observability_exporter.py +++ b/testgen/commands/run_observability_exporter.py @@ -318,6 +318,7 @@ def run_observability_exporter(project_code, test_suite): test_suites = TestSuite.select_minimal_where( TestSuite.project_code == project_code, TestSuite.test_suite == test_suite, + TestSuite.is_monitor.isnot(True), ) qty_of_exported_events = export_test_results(test_suites[0].id) click.echo(f"{qty_of_exported_events} events have been exported.") diff --git a/testgen/template/get_entities/get_test_generation_list.sql b/testgen/template/get_entities/get_test_generation_list.sql index b4322693..c14b1fa8 100644 --- a/testgen/template/get_entities/get_test_generation_list.sql +++ b/testgen/template/get_entities/get_test_generation_list.sql @@ -14,6 +14,7 @@ FROM test_definitions td JOIN test_suites ts ON td.test_suite_id = ts.id WHERE ts.project_code = :PROJECT_CODE AND ts.test_suite = :TEST_SUITE + AND ts.is_monitor IS NOT TRUE AND td.last_auto_gen_date IS NOT NULL GROUP BY ts.id, td.last_auto_gen_date, td.profiling_as_of_date, td.lock_refresh ORDER BY td.last_auto_gen_date desc; diff --git a/testgen/template/get_entities/get_test_info.sql b/testgen/template/get_entities/get_test_info.sql index 142ddc63..d07b4c78 100644 --- a/testgen/template/get_entities/get_test_info.sql +++ b/testgen/template/get_entities/get_test_info.sql @@ -39,6 +39,7 @@ INNER JOIN test_types tt ON td.test_type = tt.test_type INNER JOIN test_suites ts ON td.test_suite_id = ts.id WHERE ts.project_code = :PROJECT_CODE AND ts.test_suite = :TEST_SUITE + AND ts.is_monitor IS NOT TRUE ORDER BY td.schema_name, td.table_name, td.column_name, diff --git a/testgen/template/get_entities/get_test_run_list.sql b/testgen/template/get_entities/get_test_run_list.sql index 14079499..50f9ecc7 100644 --- a/testgen/template/get_entities/get_test_run_list.sql +++ b/testgen/template/get_entities/get_test_run_list.sql @@ -17,6 +17,7 @@ INNER JOIN test_results r ON tr.id = r.test_run_id INNER JOIN test_suites ts ON tr.test_suite_id = ts.id WHERE ts.project_code = :PROJECT_CODE AND ts.test_suite = :TEST_SUITE + AND ts.is_monitor IS NOT TRUE GROUP BY tr.id, ts.project_code, ts.test_suite, diff --git a/testgen/template/get_entities/get_test_suite.sql b/testgen/template/get_entities/get_test_suite.sql index fdbd9638..8d0a6c22 100644 --- a/testgen/template/get_entities/get_test_suite.sql +++ b/testgen/template/get_entities/get_test_suite.sql @@ -8,4 +8,5 @@ SELECT component_type FROM test_suites WHERE project_code = :PROJECT_CODE -AND test_suite = :TEST_SUITE; +AND test_suite = :TEST_SUITE +AND is_monitor IS NOT TRUE; diff --git a/testgen/template/get_entities/get_test_suite_list.sql b/testgen/template/get_entities/get_test_suite_list.sql index 4ba63e1f..1fe6e363 100644 --- a/testgen/template/get_entities/get_test_suite_list.sql +++ b/testgen/template/get_entities/get_test_suite_list.sql @@ -8,4 +8,6 @@ LEFT JOIN test_runs tr ON tr.test_suite_id = ts.id WHERE ts.project_code = :PROJECT_CODE + AND ts.is_monitor IS NOT TRUE + GROUP BY ts.id, ts.project_code, ts.test_suite, ts.connection_id, ts.test_suite_description ORDER BY ts.test_suite; diff --git a/testgen/ui/views/data_catalog.py b/testgen/ui/views/data_catalog.py index 0f09c5bf..10a6128d 100644 --- a/testgen/ui/views/data_catalog.py +++ b/testgen/ui/views/data_catalog.py @@ -774,7 +774,7 @@ def get_latest_test_issues(table_group_id: str, table_name: str, column_name: st test_results.test_type = test_types.test_type ) WHERE test_suites.table_groups_id = :table_group_id - AND test_suites.is_monitor = false + AND test_suites.is_monitor IS NOT TRUE AND table_name = :table_name {"AND column_names = :column_name" if column_name else ""} AND result_status NOT IN ('Passed', 'Log') @@ -809,7 +809,7 @@ def get_related_test_suites(table_group_id: str, table_name: str, column_name: s test_definitions.test_suite_id = test_suites.id ) WHERE test_suites.table_groups_id = :table_group_id - AND test_suites.is_monitor = false + AND test_suites.is_monitor IS NOT TRUE AND table_name = :table_name {"AND column_name = :column_name" if column_name else ""} GROUP BY test_suites.id From 183805cb97fb7984fc6c10a0e375186f6e4ca8ca Mon Sep 17 00:00:00 2001 From: Ricardo Boni Date: Thu, 21 May 2026 16:11:02 -0400 Subject: [PATCH 32/58] fix(TG-1080): cross-flavor template fixes for QUERY-style tests Align lookup_query and template SQL so the new QUERY-test functional tests (Dupe_Rows, Aggregate_Balance variants, Combo_Match, Distribution_Shift, Timeframe_Combo_*) pass across all supported flavors: - BigQuery Distribution_Shift: switch to a FULL JOIN form so categories present in only one side still contribute to the JS divergence. - Databricks Timeframe_Combo_Gain: add the missing lookup_query so the source-data drilldown works on Databricks. - Quote the MATCH-side identifiers in referential lookup_query templates so case-sensitive flavors (Snowflake, Salesforce Data 360) resolve them correctly. Co-Authored-By: Claude Opus 4.7 --- .../test_types_Aggregate_Balance.yaml | 18 +++++++-------- .../test_types_Aggregate_Balance_Percent.yaml | 18 +++++++-------- .../test_types_Aggregate_Balance_Range.yaml | 18 +++++++-------- .../test_types_Aggregate_Minimum.yaml | 18 +++++++-------- .../test_types_Combo_Match.yaml | 18 +++++++-------- .../test_types_Distribution_Shift.yaml | 22 ++++++++++++++----- .../test_types_Timeframe_Combo_Gain.yaml | 20 +++++++++++++++++ 7 files changed, 82 insertions(+), 50 deletions(-) diff --git a/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance.yaml b/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance.yaml index cb98ab7e..f38b89f4 100644 --- a/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance.yaml @@ -56,7 +56,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM `{MATCH_SCHEMA_NAME}`.`{MATCH_TABLE_NAME}` WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} @@ -83,7 +83,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM `{MATCH_SCHEMA_NAME}`.`{MATCH_TABLE_NAME}` WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -108,7 +108,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -132,7 +132,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -157,7 +157,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -182,7 +182,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -207,7 +207,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -232,7 +232,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -257,7 +257,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a diff --git a/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance_Percent.yaml b/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance_Percent.yaml index 3b0d81d7..1415731d 100644 --- a/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance_Percent.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance_Percent.yaml @@ -56,7 +56,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM `{MATCH_SCHEMA_NAME}`.`{MATCH_TABLE_NAME}` WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} @@ -85,7 +85,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM `{MATCH_SCHEMA_NAME}`.`{MATCH_TABLE_NAME}` WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -112,7 +112,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -138,7 +138,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -165,7 +165,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -192,7 +192,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -219,7 +219,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -246,7 +246,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -273,7 +273,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a diff --git a/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance_Range.yaml b/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance_Range.yaml index 2fc50146..84f20602 100644 --- a/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance_Range.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Aggregate_Balance_Range.yaml @@ -56,7 +56,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM `{MATCH_SCHEMA_NAME}`.`{MATCH_TABLE_NAME}` WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} @@ -85,7 +85,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM `{MATCH_SCHEMA_NAME}`.`{MATCH_TABLE_NAME}` WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -112,7 +112,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -138,7 +138,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -165,7 +165,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -192,7 +192,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -219,7 +219,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -246,7 +246,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -273,7 +273,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a diff --git a/testgen/template/dbsetup_test_types/test_types_Aggregate_Minimum.yaml b/testgen/template/dbsetup_test_types/test_types_Aggregate_Minimum.yaml index cd35e549..425a72e2 100644 --- a/testgen/template/dbsetup_test_types/test_types_Aggregate_Minimum.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Aggregate_Minimum.yaml @@ -56,7 +56,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL AS total, {MATCH_COLUMN_NAMES} AS match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM `{MATCH_SCHEMA_NAME}`.`{MATCH_TABLE_NAME}` WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} @@ -83,7 +83,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL as total, {MATCH_COLUMN_NAMES} as match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM `{MATCH_SCHEMA_NAME}`.`{MATCH_TABLE_NAME}` WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -108,7 +108,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL as total, {MATCH_COLUMN_NAMES} as match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -132,7 +132,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL as total, {MATCH_COLUMN_NAMES} as match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -157,7 +157,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL as total, {MATCH_COLUMN_NAMES} as match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -182,7 +182,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL as total, {MATCH_COLUMN_NAMES} as match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -207,7 +207,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL as total, {MATCH_COLUMN_NAMES} as match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -232,7 +232,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL as total, {MATCH_COLUMN_NAMES} as match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a @@ -257,7 +257,7 @@ test_types: {HAVING_CONDITION} UNION ALL SELECT {MATCH_GROUPBY_NAMES}, NULL as total, {MATCH_COLUMN_NAMES} as match_total - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} ) a diff --git a/testgen/template/dbsetup_test_types/test_types_Combo_Match.yaml b/testgen/template/dbsetup_test_types/test_types_Combo_Match.yaml index 3b027325..f4016136 100644 --- a/testgen/template/dbsetup_test_types/test_types_Combo_Match.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Combo_Match.yaml @@ -54,7 +54,7 @@ test_types: {HAVING_CONDITION} EXCEPT DISTINCT SELECT {MATCH_GROUPBY_NAMES} - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM `{MATCH_SCHEMA_NAME}`.`{MATCH_TABLE_NAME}` WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} @@ -76,7 +76,7 @@ test_types: {HAVING_CONDITION} EXCEPT SELECT {MATCH_GROUPBY_NAMES} - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM `{MATCH_SCHEMA_NAME}`.`{MATCH_TABLE_NAME}` WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} @@ -98,7 +98,7 @@ test_types: {HAVING_CONDITION} EXCEPT SELECT {MATCH_GROUPBY_NAMES} - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} @@ -119,7 +119,7 @@ test_types: {HAVING_CONDITION} EXCEPT SELECT {MATCH_GROUPBY_NAMES} - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} @@ -141,7 +141,7 @@ test_types: {HAVING_CONDITION} EXCEPT SELECT {MATCH_GROUPBY_NAMES} - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} @@ -163,7 +163,7 @@ test_types: {HAVING_CONDITION} EXCEPT SELECT {MATCH_GROUPBY_NAMES} - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} @@ -185,7 +185,7 @@ test_types: {HAVING_CONDITION} EXCEPT SELECT {MATCH_GROUPBY_NAMES} - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} @@ -207,7 +207,7 @@ test_types: {HAVING_CONDITION} MINUS SELECT {MATCH_GROUPBY_NAMES} - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} @@ -229,7 +229,7 @@ test_types: {HAVING_CONDITION} EXCEPT SELECT {MATCH_GROUPBY_NAMES} - FROM {MATCH_SCHEMA_NAME}.{MATCH_TABLE_NAME} + FROM "{MATCH_SCHEMA_NAME}"."{MATCH_TABLE_NAME}" WHERE {MATCH_SUBSET_CONDITION} GROUP BY {MATCH_GROUPBY_NAMES} {MATCH_HAVING_CONDITION} diff --git a/testgen/template/dbsetup_test_types/test_types_Distribution_Shift.yaml b/testgen/template/dbsetup_test_types/test_types_Distribution_Shift.yaml index 7a3b2361..666bf095 100644 --- a/testgen/template/dbsetup_test_types/test_types_Distribution_Shift.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Distribution_Shift.yaml @@ -49,13 +49,25 @@ test_types: lookup_query: |- WITH latest_ver AS ( SELECT {CONCAT_COLUMNS} AS category, - CAST(COUNT(*) AS FLOAT64) / SUM(COUNT(*)) OVER() AS pct_of_total - FROM `{TARGET_SCHEMA}`.`{TABLE_NAME}` v1 + CAST(COUNT(*) AS FLOAT64) / CAST(SUM(COUNT(*)) OVER () AS FLOAT64) AS pct_of_total + FROM `{TARGET_SCHEMA}.{TABLE_NAME}` v1 WHERE {SUBSET_CONDITION} - GROUP BY {CONCAT_COLUMNS} + GROUP BY {COLUMN_NAME_NO_QUOTES} + ), + older_ver AS ( + SELECT {CONCAT_MATCH_GROUPBY} AS category, + CAST(COUNT(*) AS FLOAT64) / CAST(SUM(COUNT(*)) OVER () AS FLOAT64) AS pct_of_total + FROM `{MATCH_SCHEMA_NAME}.{TABLE_NAME}` v2 + WHERE {MATCH_SUBSET_CONDITION} + GROUP BY {MATCH_GROUPBY_NAMES} ) - SELECT * - FROM latest_ver + SELECT COALESCE(l.category, o.category) AS category, + o.pct_of_total AS old_pct, + l.pct_of_total AS new_pct + FROM latest_ver l + FULL JOIN older_ver o + ON l.category = o.category + ORDER BY COALESCE(l.category, o.category) LIMIT {LIMIT}; error_type: Test Results - id: '1336' diff --git a/testgen/template/dbsetup_test_types/test_types_Timeframe_Combo_Gain.yaml b/testgen/template/dbsetup_test_types/test_types_Timeframe_Combo_Gain.yaml index 13bdb85c..34329e26 100644 --- a/testgen/template/dbsetup_test_types/test_types_Timeframe_Combo_Gain.yaml +++ b/testgen/template/dbsetup_test_types/test_types_Timeframe_Combo_Gain.yaml @@ -60,6 +60,26 @@ test_types: GROUP BY {COLUMN_NAME_NO_QUOTES} LIMIT {LIMIT}; error_type: Test Results + - id: '1396' + test_id: '1508' + test_type: Timeframe_Combo_Gain + sql_flavor: databricks + lookup_type: null + lookup_query: |- + SELECT {COLUMN_NAME_NO_QUOTES} + FROM `{TARGET_SCHEMA}`.`{TABLE_NAME}` + WHERE {SUBSET_CONDITION} + AND {WINDOW_DATE_COLUMN} >= (SELECT MAX({WINDOW_DATE_COLUMN}) FROM `{TARGET_SCHEMA}`.`{TABLE_NAME}`) - 2 * {WINDOW_DAYS} + AND {WINDOW_DATE_COLUMN} < (SELECT MAX({WINDOW_DATE_COLUMN}) FROM `{TARGET_SCHEMA}`.`{TABLE_NAME}`) - {WINDOW_DAYS} + GROUP BY {COLUMN_NAME_NO_QUOTES} + EXCEPT + SELECT {COLUMN_NAME_NO_QUOTES} + FROM `{TARGET_SCHEMA}`.`{TABLE_NAME}` + WHERE {SUBSET_CONDITION} + AND {WINDOW_DATE_COLUMN} >= (SELECT MAX({WINDOW_DATE_COLUMN}) FROM `{TARGET_SCHEMA}`.`{TABLE_NAME}`) - {WINDOW_DAYS} + GROUP BY {COLUMN_NAME_NO_QUOTES} + LIMIT {LIMIT}; + error_type: Test Results - id: '1263' test_id: '1508' test_type: Timeframe_Combo_Gain From 167a7b11ee8b09a29a33e7641d72f20fd88cb9dc Mon Sep 17 00:00:00 2001 From: Luis Date: Tue, 12 May 2026 17:15:59 -0400 Subject: [PATCH 33/58] refactor(mcp): update inventory tool to display scorecards --- testgen/common/models/scores.py | 59 +++++++ testgen/mcp/services/inventory_service.py | 50 +++++- .../common/models/test_score_definition.py | 97 ++++++++++++ tests/unit/mcp/test_inventory_service.py | 146 ++++++++++++++++++ 4 files changed, 351 insertions(+), 1 deletion(-) diff --git a/testgen/common/models/scores.py b/testgen/common/models/scores.py index cc693244..85f3ff94 100644 --- a/testgen/common/models/scores.py +++ b/testgen/common/models/scores.py @@ -140,6 +140,65 @@ def get(cls, id_: str) -> Self | None: definition = db_session.scalars(query).first() return definition + @classmethod + def list_with_table_group_targets( + cls, + project_code: str, + ) -> list[tuple[UUID, str, list[str]]]: + """Return all scorecards in the project, each paired with the list of + `table_groups_name` values their criteria reference. + + Walks both root filters (`criteria.filters`) and the `next_filter` chain + via a recursive CTE. A scorecard with zero name filters has an empty + list; multiple are returned in chain order. + + Single query. Does NOT eagerly load the criteria/filter ORM objects — + the caller gets only (id, name, target names). Used by the MCP + inventory tool to surface scorecard IDs under each table group. + """ + # Seed: root filters joined through criteria for the project's definitions. + seed = ( + select( + ScoreDefinitionCriteria.definition_id.label("definition_id"), + ScoreDefinitionFilter.field.label("field"), + ScoreDefinitionFilter.value.label("value"), + ScoreDefinitionFilter.next_filter_id.label("next_filter_id"), + ) + .select_from(ScoreDefinitionCriteria) + .join(ScoreDefinitionFilter, ScoreDefinitionFilter.criteria_id == ScoreDefinitionCriteria.id) + .join(ScoreDefinition, ScoreDefinition.id == ScoreDefinitionCriteria.definition_id) + .where(ScoreDefinition.project_code == project_code) + .cte("filter_walk", recursive=True) + ) + # Recursive step: follow next_filter_id to walk the chain. + chain = aliased(ScoreDefinitionFilter) + filter_walk = seed.union_all( + select( + seed.c.definition_id, + chain.field, + chain.value, + chain.next_filter_id, + ) + .select_from(seed) + .join(chain, chain.id == seed.c.next_filter_id) + ) + + tg_names = ( + func.array_agg(filter_walk.c.value) + .filter(filter_walk.c.field == "table_groups_name") + .label("tg_names") + ) + query = ( + select(ScoreDefinition.id, ScoreDefinition.name, tg_names) + .select_from(ScoreDefinition) + .outerjoin(filter_walk, filter_walk.c.definition_id == ScoreDefinition.id) + .where(ScoreDefinition.project_code == project_code) + .group_by(ScoreDefinition.id, ScoreDefinition.name) + .order_by(ScoreDefinition.name) + ) + rows = get_current_session().execute(query).all() + return [(row.id, row.name, list(row.tg_names) if row.tg_names else []) for row in rows] + @classmethod def all( cls, diff --git a/testgen/mcp/services/inventory_service.py b/testgen/mcp/services/inventory_service.py index a20aef31..d744c89c 100644 --- a/testgen/mcp/services/inventory_service.py +++ b/testgen/mcp/services/inventory_service.py @@ -5,6 +5,7 @@ from testgen.common.models import get_current_session from testgen.common.models.connection import Connection from testgen.common.models.project import Project +from testgen.common.models.scores import ScoreDefinition from testgen.common.models.table_group import TableGroup, TableGroupSummary from testgen.common.models.test_suite import TestSuite from testgen.utils import friendly_score, score @@ -95,10 +96,12 @@ def get_inventory( view_codes_set = set(view_project_codes) profiling_by_tg: dict[UUID, TableGroupSummary] = {} + scorecards_by_project: dict[str, tuple[dict[str, list[tuple[str, str]]], list[tuple[str, str]]]] = {} for code in view_codes_set: summaries, _ = TableGroup.select_summary(code) for summary in summaries: profiling_by_tg[summary.id] = summary + scorecards_by_project[code] = _scorecards_by_table_group(code) # Format as Markdown lines = ["# Data Inventory\n"] @@ -125,6 +128,10 @@ def get_inventory( for group_id, group in conn["groups"].items(): summary = profiling_by_tg.get(group_id) if can_view else None + tg_scorecards: list[tuple[str, str]] = [] + if can_view: + by_tg, _ = scorecards_by_project[project_code] + tg_scorecards = by_tg.get(group["name"], []) if compact_groups or not can_view: line = ( @@ -133,6 +140,8 @@ def get_inventory( ) if summary: line += f", {_profiling_summary_fragment(summary)}" + if tg_scorecards: + line += f", scorecards: {len(tg_scorecards)}" lines.append(line) continue @@ -143,26 +152,65 @@ def get_inventory( if summary: lines.append(f"_{_profiling_summary_fragment(summary)}_\n") + if tg_scorecards: + lines.append("**Scorecards:**") + for sid, name in tg_scorecards: + lines.append(f"- **{name}** (id: `{sid}`)") + lines.append("") + if not group["suites"]: lines.append("_No test suites._\n") continue + lines.append("**Test Suites:**") for suite in group["suites"]: lines.append(f"- **{suite['name']}** (id: `{suite['id']}`)") lines.append("") lines.append("") + if can_view: + _, multi = scorecards_by_project.get(project_code, ({}, [])) + if multi: + lines.append("### Scorecards spanning multiple table groups\n") + for sid, name in multi: + lines.append(f"- **{name}** (id: `{sid}`)") + lines.append("") + lines.append( "---\n" "Use `list_tables(table_group_id='...')` to see tables in a group.\n" "Use `list_test_suites(project_code='...')` for suite details and latest run stats.\n" - "Use `list_profiling_summaries(table_group_id='...')` for the quality score rollup and hygiene issue counts." + "Use `list_profiling_summaries(table_group_id='...')` for the quality score rollup and hygiene issue counts.\n" + "Use `get_scorecard(scorecard_id='...')` for the score breakdown and category detail." ) return "\n".join(lines) +def _scorecards_by_table_group( + project_code: str, +) -> tuple[dict[str, list[tuple[str, str]]], list[tuple[str, str]]]: + """Index scorecards in a project by the table groups they target by name. + + Returns (by_tg_name, multi_or_none): + - by_tg_name[tg_name] = list of (scorecard_id_str, scorecard_name) for + scorecards that declare a `table_groups_name = tg_name` filter. + - multi_or_none lists scorecards whose name-filter count is not exactly 1 + (zero filters → project-wide; multiple → spans TGs by name). Such + scorecards appear under every named TG AND in this list. + """ + by_tg: dict[str, list[tuple[str, str]]] = {} + multi_or_none: list[tuple[str, str]] = [] + for sc_id, sc_name, tg_names in ScoreDefinition.list_with_table_group_targets(project_code): + entry = (str(sc_id), sc_name) + for tg_name in tg_names: + by_tg.setdefault(tg_name, []).append(entry) + if len(tg_names) != 1: + multi_or_none.append(entry) + return by_tg, multi_or_none + + def _profiling_summary_fragment(summary: TableGroupSummary) -> str: """Compact one-liner of profiling metadata for a table group.""" if not summary.latest_profile_id: diff --git a/tests/unit/common/models/test_score_definition.py b/tests/unit/common/models/test_score_definition.py index 1080e37e..3b678e09 100644 --- a/tests/unit/common/models/test_score_definition.py +++ b/tests/unit/common/models/test_score_definition.py @@ -6,6 +6,7 @@ """ from unittest.mock import MagicMock, patch +from uuid import uuid4 import pytest @@ -16,6 +17,9 @@ ScoreDefinitionFilter, ) + +from testgen.common.models.scores import ScoreDefinition + pytestmark = pytest.mark.unit @@ -111,3 +115,96 @@ def test_categories_query_uses_column_template_for_column_category(): # Column-grouped template aggregates by a placeholder substituted into the SELECT. assert "business_domain" in categories_sql assert CDE_FILTER_FRAGMENT in categories_sql +# --- list_with_table_group_targets --- + + +def _row(definition_id, name, tg_names): + """Simulate a row returned by the recursive-CTE aggregate query.""" + row = MagicMock() + row.id = definition_id + row.name = name + row.tg_names = tg_names + return row + + +@patch("testgen.common.models.scores.get_current_session") +def test_list_with_table_group_targets_single_name_filter(mock_session_fn): + """A scorecard with one table_groups_name filter yields (id, name, [tg_name]).""" + def_id = uuid4() + mock_result = MagicMock() + mock_result.all.return_value = [_row(def_id, "orders-sc", ["orders"])] + mock_session_fn.return_value.execute.return_value = mock_result + + out = ScoreDefinition.list_with_table_group_targets("proj") + + assert out == [(def_id, "orders-sc", ["orders"])] + + +@patch("testgen.common.models.scores.get_current_session") +def test_list_with_table_group_targets_multiple_name_filters(mock_session_fn): + """A scorecard with multiple table_groups_name filters yields all names.""" + def_id = uuid4() + mock_result = MagicMock() + mock_result.all.return_value = [_row(def_id, "multi-sc", ["orders", "customers"])] + mock_session_fn.return_value.execute.return_value = mock_result + + out = ScoreDefinition.list_with_table_group_targets("proj") + + assert out == [(def_id, "multi-sc", ["orders", "customers"])] + + +@patch("testgen.common.models.scores.get_current_session") +def test_list_with_table_group_targets_no_name_filter(mock_session_fn): + """A scorecard with no table_groups_name filter yields an empty list of targets.""" + def_id = uuid4() + mock_result = MagicMock() + # Postgres array_agg with FILTER returns NULL when no rows match — the method + # must normalize this to []. + mock_result.all.return_value = [_row(def_id, "metadata-only-sc", None)] + mock_session_fn.return_value.execute.return_value = mock_result + + out = ScoreDefinition.list_with_table_group_targets("proj") + + assert out == [(def_id, "metadata-only-sc", [])] + + +@patch("testgen.common.models.scores.get_current_session") +def test_list_with_table_group_targets_filters_by_project_code(mock_session_fn): + """The query filters on project_code via the WHERE clause.""" + mock_result = MagicMock() + mock_result.all.return_value = [] + mock_session_fn.return_value.execute.return_value = mock_result + + ScoreDefinition.list_with_table_group_targets("my-project") + + args, _ = mock_session_fn.return_value.execute.call_args + compiled = args[0].compile(compile_kwargs={"literal_binds": True}) + sql = str(compiled) + assert "project_code" in sql + assert "'my-project'" in sql + + +@patch("testgen.common.models.scores.get_current_session") +def test_list_with_table_group_targets_uses_recursive_cte_on_filter_chain(mock_session_fn): + """The query SQL walks score_definition_filters via next_filter_id (recursive CTE).""" + mock_result = MagicMock() + mock_result.all.return_value = [] + mock_session_fn.return_value.execute.return_value = mock_result + + ScoreDefinition.list_with_table_group_targets("proj") + + args, _ = mock_session_fn.return_value.execute.call_args + sql = str(args[0].compile(compile_kwargs={"literal_binds": True})) + assert "RECURSIVE" in sql.upper() + assert "next_filter_id" in sql + assert "table_groups_name" in sql + + +@patch("testgen.common.models.scores.get_current_session") +def test_list_with_table_group_targets_empty_project(mock_session_fn): + """When the project has no scorecards, returns an empty list.""" + mock_result = MagicMock() + mock_result.all.return_value = [] + mock_session_fn.return_value.execute.return_value = mock_result + + assert ScoreDefinition.list_with_table_group_targets("proj") == [] diff --git a/tests/unit/mcp/test_inventory_service.py b/tests/unit/mcp/test_inventory_service.py index ed0ba1bd..51329145 100644 --- a/tests/unit/mcp/test_inventory_service.py +++ b/tests/unit/mcp/test_inventory_service.py @@ -17,6 +17,15 @@ def table_group_select_summary_mock(): yield mock +@pytest.fixture(autouse=True) +def scorecards_by_project_mock(): + with patch( + "testgen.mcp.services.inventory_service.ScoreDefinition.list_with_table_group_targets" + ) as mock: + mock.return_value = [] + yield mock + + def _make_row(project_code="demo", project_name="Demo", connection_id=1, connection_name="main", table_group_id=None, table_groups_name="core", table_group_schema="public", test_suite_id=None, test_suite="Quality"): @@ -142,6 +151,7 @@ def test_get_inventory_with_view_shows_all_details(mock_select, session_mock): result = get_inventory(project_codes=["demo"], view_project_codes=["demo"]) assert "main" in result # connection name shown + assert "**Test Suites:**" in result assert "Visible Suite" in result assert str(suite_id) in result assert "requires `view` permission" not in result @@ -222,3 +232,139 @@ def test_get_inventory_never_profiled_fragment( assert "not profiled yet" in result assert "hygiene issues" not in result assert "Score" not in result + + +# ---------------------------------------------------------------------- +# Scorecard rendering +# ---------------------------------------------------------------------- + + +@patch("testgen.mcp.services.inventory_service.select") +def test_get_inventory_lists_single_tg_scorecard_under_tg( + mock_select, session_mock, scorecards_by_project_mock, +): + """A scorecard targeting one TG by name renders as a bullet under that TG.""" + tg_id = uuid4() + sc_id = uuid4() + session_mock.execute.return_value.all.return_value = [ + _make_row(table_group_id=tg_id, table_groups_name="core"), + ] + scorecards_by_project_mock.return_value = [(sc_id, "Core Scorecard", ["core"])] + + from testgen.mcp.services.inventory_service import get_inventory + result = get_inventory(project_codes=["demo"], view_project_codes=["demo"]) + + assert "**Scorecards:**" in result + assert f"- **Core Scorecard** (id: `{sc_id}`)" in result + # No spanning section when every scorecard targets exactly one TG. + assert "spanning multiple table groups" not in result + + +@patch("testgen.mcp.services.inventory_service.select") +def test_get_inventory_multi_tg_scorecard_appears_under_each_named_tg_and_spanning( + mock_select, session_mock, scorecards_by_project_mock, +): + """A scorecard targeting two TGs appears under each TG AND in the spanning section.""" + tg_a, tg_b = uuid4(), uuid4() + sc_id = uuid4() + session_mock.execute.return_value.all.return_value = [ + _make_row(table_group_id=tg_a, table_groups_name="orders", test_suite_id=uuid4()), + _make_row(table_group_id=tg_b, table_groups_name="customers", test_suite_id=uuid4()), + ] + scorecards_by_project_mock.return_value = [(sc_id, "Cross", ["orders", "customers"])] + + from testgen.mcp.services.inventory_service import get_inventory + result = get_inventory(project_codes=["demo"], view_project_codes=["demo"]) + + assert result.count(f"- **Cross** (id: `{sc_id}`)") == 3 + assert "### Scorecards spanning multiple table groups" in result + + +@patch("testgen.mcp.services.inventory_service.select") +def test_get_inventory_no_name_filter_scorecard_in_spanning_section_only( + mock_select, session_mock, scorecards_by_project_mock, +): + """A scorecard with no table_groups_name filter only appears in the spanning section.""" + tg_id = uuid4() + sc_id = uuid4() + session_mock.execute.return_value.all.return_value = [_make_row(table_group_id=tg_id)] + scorecards_by_project_mock.return_value = [(sc_id, "Metadata Only", [])] + + from testgen.mcp.services.inventory_service import get_inventory + result = get_inventory(project_codes=["demo"], view_project_codes=["demo"]) + + assert "### Scorecards spanning multiple table groups" in result + assert f"- **Metadata Only** (id: `{sc_id}`)" in result + # The TG block should not have a Scorecards: line. + assert "**Scorecards:**" not in result + + +@patch("testgen.mcp.services.inventory_service.select") +def test_get_inventory_compact_mode_emits_scorecards_count_no_ids( + mock_select, session_mock, scorecards_by_project_mock, +): + """Compact mode (>50 groups) appends 'scorecards: N' to the one-liner; no IDs.""" + rows = [ + _make_row( + table_group_id=uuid4(), + table_groups_name=f"Group_{i}", + test_suite=f"Suite_{i}", + test_suite_id=uuid4(), + ) + for i in range(55) + ] + session_mock.execute.return_value.all.return_value = rows + sc_id = uuid4() + scorecards_by_project_mock.return_value = [(sc_id, "G0 Scorecard", ["Group_0"])] + + from testgen.mcp.services.inventory_service import get_inventory + result = get_inventory(project_codes=["demo"], view_project_codes=["demo"]) + + assert "scorecards: 1" in result + assert str(sc_id) not in result # no IDs in compact mode + + +@patch("testgen.mcp.services.inventory_service.select") +def test_get_inventory_catalog_only_project_hides_scorecards( + mock_select, session_mock, scorecards_by_project_mock, +): + """Without view permission, the ORM lookup is skipped and no scorecard text renders.""" + tg_id = uuid4() + session_mock.execute.return_value.all.return_value = [_make_row(table_group_id=tg_id)] + scorecards_by_project_mock.return_value = [(uuid4(), "Hidden", ["core"])] + + from testgen.mcp.services.inventory_service import get_inventory + result = get_inventory(project_codes=["demo"], view_project_codes=[]) + + scorecards_by_project_mock.assert_not_called() + assert "Scorecards" not in result + assert "Hidden" not in result + + +@patch("testgen.mcp.services.inventory_service.select") +def test_get_inventory_footer_includes_get_scorecard_hint( + mock_select, session_mock, scorecards_by_project_mock, +): + """Footer mentions get_scorecard for discoverability.""" + session_mock.execute.return_value.all.return_value = [_make_row()] + + from testgen.mcp.services.inventory_service import get_inventory + result = get_inventory(project_codes=["demo"], view_project_codes=["demo"]) + + assert "get_scorecard(scorecard_id=" in result + + +@patch("testgen.mcp.services.inventory_service.select") +def test_get_inventory_no_scorecards_omits_scorecards_line( + mock_select, session_mock, scorecards_by_project_mock, +): + """When no scorecards target a TG, the Scorecards line is omitted entirely.""" + tg_id = uuid4() + session_mock.execute.return_value.all.return_value = [_make_row(table_group_id=tg_id)] + scorecards_by_project_mock.return_value = [] + + from testgen.mcp.services.inventory_service import get_inventory + result = get_inventory(project_codes=["demo"], view_project_codes=["demo"]) + + assert "**Scorecards:**" not in result + assert "spanning multiple table groups" not in result From cec8098b894b8d16b9dfead62566ab51685e1283 Mon Sep 17 00:00:00 2001 From: Luis Date: Wed, 13 May 2026 12:40:22 -0400 Subject: [PATCH 34/58] feat(mcp): add new tool get_quality_scores Implements get_quality_scores: a flexible rollup over TestGen's scoring engine with project_code / table_group_id scoping, twelve group_by dimensions, OR-within-field / AND-across-field filters, score_type (Combined/CDE), and opt-in include_issue_ct / include_impact. Builds a transient ScoreDefinition and reuses as_score_card + get_score_card_breakdown; uses user-facing labels throughout, validates filter values before SQL substitution, gates on project view permission, and inherits monitor-suite exclusion from the scoring views. --- testgen/common/models/scores.py | 40 +- testgen/mcp/server.py | 2 + testgen/mcp/tools/common.py | 113 +++ testgen/mcp/tools/quality_scores.py | 347 ++++++++ .../common/models/test_score_definition.py | 87 +- tests/unit/mcp/test_tools_common.py | 131 ++++ tests/unit/mcp/test_tools_quality_scores.py | 742 ++++++++++++++++++ 7 files changed, 1458 insertions(+), 4 deletions(-) create mode 100644 testgen/mcp/tools/quality_scores.py create mode 100644 tests/unit/mcp/test_tools_quality_scores.py diff --git a/testgen/common/models/scores.py b/testgen/common/models/scores.py index 85f3ff94..803d7f7d 100644 --- a/testgen/common/models/scores.py +++ b/testgen/common/models/scores.py @@ -12,7 +12,22 @@ from typing import Literal, Self, TypedDict from uuid import UUID, uuid4 -from sqlalchemy import Boolean, Column, DateTime, Enum, Float, ForeignKey, Integer, String, delete, func, select, text +from sqlalchemy import ( + Boolean, + Column, + DateTime, + Enum, + Float, + ForeignKey, + Integer, + String, + column, + delete, + func, + select, + table, + text, +) from sqlalchemy.dialects import postgresql from sqlalchemy.orm import aliased, attributes, relationship @@ -535,6 +550,29 @@ def recalculate_scores_history(self) -> None: self.history = list(current_history.values()) + def get_overall_issue_ct(self) -> int: + """Sum of hygiene + test issue counts under this definition's filters. + + Reuses the same filter machinery as `as_score_card` so the rolled-up + count matches the score that call returns. + """ + if not self.criteria.has_filters(): + return 0 + + where_clause = text(" AND ".join(self._get_raw_query_filters())) + session = get_current_session() + + def _sum_issue_ct(view_name: str) -> int: + view = table(view_name, column("issue_ct")) + return int(session.execute( + select(func.coalesce(func.sum(view.c.issue_ct), 0)).where(where_clause) + ).scalar() or 0) + + return ( + _sum_issue_ct("v_dq_profile_scoring_latest_by_column") + + _sum_issue_ct("v_dq_test_scoring_latest_by_column") + ) + def _get_raw_query_filters(self, cde_only: bool = False, prefix: str | None = None) -> list[str]: extra_filters = [ f"{prefix or ''}project_code = '{self.project_code}'" diff --git a/testgen/mcp/server.py b/testgen/mcp/server.py index 838ceb26..1e77724d 100644 --- a/testgen/mcp/server.py +++ b/testgen/mcp/server.py @@ -167,6 +167,7 @@ def build_mcp_server( list_profiling_summaries, search_columns, ) + from testgen.mcp.tools.quality_scores import get_quality_scores from testgen.mcp.tools.reference import ( column_profile_fields_resource, get_test_type, @@ -278,6 +279,7 @@ def safe_prompt(fn): safe_tool(get_schedule) safe_tool(update_schedule) safe_tool(delete_schedule) + safe_tool(get_quality_scores) # Resources safe_resource("testgen://test-types", test_types_resource) diff --git a/testgen/mcp/tools/common.py b/testgen/mcp/tools/common.py index acc382b6..7e5ed394 100644 --- a/testgen/mcp/tools/common.py +++ b/testgen/mcp/tools/common.py @@ -97,6 +97,119 @@ def parse_quality_dimension(value: str) -> QualityDimension: raise MCPUserError(f"Invalid quality_dimension `{value}`. Valid values: {valid}") from err +class ScoreGroupBy(StrEnum): + """User-facing values accepted for the ``group_by`` argument on quality-score rollups.""" + + QUALITY_DIMENSION = "Quality Dimension" + IMPACT_DIMENSION = "Impact Dimension" + SEMANTIC_DATA_TYPE = "Semantic Data Type" + TABLE_GROUP = "Table Group" + DATA_LOCATION = "Data Location" + DATA_SOURCE = "Data Source" + SOURCE_SYSTEM = "Source System" + SOURCE_PROCESS = "Source Process" + BUSINESS_DOMAIN = "Business Domain" + STAKEHOLDER_GROUP = "Stakeholder Group" + TRANSFORM_LEVEL = "Transform Level" + DATA_PRODUCT = "Data Product" + + +# Translates the user-facing label to the internal DB column name used by +# ``ScoreCategory`` and the criteria filter list. +SCORE_GROUP_BY_TO_COLUMN: dict[ScoreGroupBy, str] = { + ScoreGroupBy.QUALITY_DIMENSION: "dq_dimension", + ScoreGroupBy.IMPACT_DIMENSION: "impact_dimension", + ScoreGroupBy.SEMANTIC_DATA_TYPE: "semantic_data_type", + ScoreGroupBy.TABLE_GROUP: "table_groups_name", + ScoreGroupBy.DATA_LOCATION: "data_location", + ScoreGroupBy.DATA_SOURCE: "data_source", + ScoreGroupBy.SOURCE_SYSTEM: "source_system", + ScoreGroupBy.SOURCE_PROCESS: "source_process", + ScoreGroupBy.BUSINESS_DOMAIN: "business_domain", + ScoreGroupBy.STAKEHOLDER_GROUP: "stakeholder_group", + ScoreGroupBy.TRANSFORM_LEVEL: "transform_level", + ScoreGroupBy.DATA_PRODUCT: "data_product", +} + + +class ScoreFilterField(StrEnum): + """User-facing values accepted for ``filters[].field`` on quality-score rollups. + + Same shape as ``ScoreGroupBy`` minus the two dimension values — Quality + Dimension and Impact Dimension are valid as ``group_by``, not as filter + fields. The duplication is deliberate: each argument has its own enum so + the valid-value set for each is read off one StrEnum. + """ + + SEMANTIC_DATA_TYPE = "Semantic Data Type" + TABLE_GROUP = "Table Group" + DATA_LOCATION = "Data Location" + DATA_SOURCE = "Data Source" + SOURCE_SYSTEM = "Source System" + SOURCE_PROCESS = "Source Process" + BUSINESS_DOMAIN = "Business Domain" + STAKEHOLDER_GROUP = "Stakeholder Group" + TRANSFORM_LEVEL = "Transform Level" + DATA_PRODUCT = "Data Product" + + +SCORE_FILTER_FIELD_TO_COLUMN: dict[ScoreFilterField, str] = { + ScoreFilterField.SEMANTIC_DATA_TYPE: "semantic_data_type", + ScoreFilterField.TABLE_GROUP: "table_groups_name", + ScoreFilterField.DATA_LOCATION: "data_location", + ScoreFilterField.DATA_SOURCE: "data_source", + ScoreFilterField.SOURCE_SYSTEM: "source_system", + ScoreFilterField.SOURCE_PROCESS: "source_process", + ScoreFilterField.BUSINESS_DOMAIN: "business_domain", + ScoreFilterField.STAKEHOLDER_GROUP: "stakeholder_group", + ScoreFilterField.TRANSFORM_LEVEL: "transform_level", + ScoreFilterField.DATA_PRODUCT: "data_product", +} + + +class ScoreType(StrEnum): + """User-facing values accepted for the ``score_type`` argument.""" + + COMBINED = "Combined" + CDE = "CDE" + + +# Translates to the internal sentinel consumed by ``ScoreDefinition.total_score`` +# / ``cde_score`` flag logic. +SCORE_TYPE_TO_INTERNAL: dict[ScoreType, str] = { + ScoreType.COMBINED: "total", + ScoreType.CDE: "cde", +} + + +def parse_score_group_by(value: str) -> ScoreGroupBy: + try: + return ScoreGroupBy(value) + except ValueError as err: + valid = ", ".join(g.value for g in ScoreGroupBy) + raise MCPUserError(f"Invalid group_by `{value}`. Valid values: {valid}") from err + + +def parse_score_filter_field(value: str) -> ScoreFilterField: + try: + return ScoreFilterField(value) + except ValueError as err: + if value in {ScoreGroupBy.QUALITY_DIMENSION.value, ScoreGroupBy.IMPACT_DIMENSION.value}: + raise MCPUserError( + f"`{value}` is not a valid filter field — use group_by='{value}' instead" + ) from err + valid = ", ".join(f.value for f in ScoreFilterField) + raise MCPUserError(f"Invalid filter field `{value}`. Valid values: {valid}") from err + + +def parse_score_type(value: str) -> ScoreType: + try: + return ScoreType(value) + except ValueError as err: + valid = ", ".join(s.value for s in ScoreType) + raise MCPUserError(f"Invalid score_type `{value}`. Valid values: {valid}") from err + + # Maps user-facing run-status labels to underlying ``JobStatus`` values. Transient states # (Starting/Canceling) are excluded because they're sub-second and noisy as filters. # ``Pending`` collapses PENDING+CLAIMED; ``Canceled`` collapses CANCEL_REQUESTED+CANCELED. diff --git a/testgen/mcp/tools/quality_scores.py b/testgen/mcp/tools/quality_scores.py new file mode 100644 index 00000000..c91c90f3 --- /dev/null +++ b/testgen/mcp/tools/quality_scores.py @@ -0,0 +1,347 @@ +from testgen.common.models import with_database_session +from testgen.common.models.scores import ( + ScoreCategory, + ScoreDefinition, + ScoreDefinitionCriteria, +) +from testgen.common.models.table_group import TableGroup +from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError +from testgen.mcp.permissions import get_project_permissions, mcp_permission +from testgen.mcp.tools.common import ( + SCORE_FILTER_FIELD_TO_COLUMN, + SCORE_GROUP_BY_TO_COLUMN, + DocGroup, + ScoreGroupBy, + ScoreType, + parse_score_filter_field, + parse_score_group_by, + parse_score_type, + resolve_table_group, +) +from testgen.mcp.tools.markdown import MdDoc +from testgen.utils import friendly_score, friendly_score_impact + +_DOC_GROUP = DocGroup.DISCOVER + +_VALUE_MAX_LEN = 256 +_VALUE_FORBIDDEN_CHARS = frozenset("'\";\\\x00") + +# Defensive Python-side cap on grouped output. The category-scores SQL doesn't +# LIMIT, and most valid group_by values produce small bounded result sets +# (≤ ~15 dimensions/domains), but pathological metadata could blow this up. +_ROW_CAP = 100 + +_COMBINED_LABEL = "Combined Score" +_CDE_LABEL = "CDE Score" + + +@with_database_session +@mcp_permission("view") +def get_quality_scores( + *, + project_code: str | None = None, + table_group_id: str | None = None, + group_by: str | None = None, + score_type: str | None = None, + filters: list[dict] | None = None, + include_issue_ct: bool = False, + include_impact: bool = False, +) -> str: + """Quality-score rollup with optional grouping and filtering. + + Args: + project_code: Scope to a single project. Omit to roll across every + project the caller can view. + table_group_id: Scope to a single table group, e.g. from + ``get_data_inventory``. Mutually exclusive with ``project_code``. + group_by: One of 'Quality Dimension', 'Impact Dimension', + 'Semantic Data Type', 'Table Group', 'Data Location', + 'Data Source', 'Source System', 'Source Process', + 'Business Domain', 'Stakeholder Group', 'Transform Level', + 'Data Product'. Omit for the unfiltered overall score. + score_type: Narrows which score(s) are reported. Omit (default) to + show both Combined and CDE; pass 'Combined' to show only the + Combined Score, or 'CDE' to show only the CDE Score. + filters: List of {"field": str, "value": str} pairs. Same-field + filters OR together; different fields AND together. Valid fields + are the same as ``group_by`` except 'Quality Dimension' and + 'Impact Dimension', which are valid as ``group_by`` only. Filter + values must not contain quotes, semicolons, or backslashes. + include_issue_ct: When True, include the count of contributing issues + (hygiene + test failures). + include_impact: When True, include the per-category impact on the + overall score (the percentage contribution to total quality + loss). Only affects grouped output. + """ + perms = get_project_permissions() + + if project_code is not None and table_group_id is not None: + raise MCPUserError( + "Pass either `project_code` or `table_group_id`, not both." + ) + + parsed_score_type: ScoreType | None = ( + parse_score_type(score_type) if score_type is not None else None + ) + parsed_group_by: ScoreGroupBy | None = ( + parse_score_group_by(group_by) if group_by is not None else None + ) + + user_filters = _validate_filters(filters) + + if table_group_id is not None: + table_group = resolve_table_group(table_group_id) + scope_codes = [table_group.project_code] + table_group_name = table_group.table_groups_name + elif project_code is not None: + perms.verify_access( + project_code, + not_found=MCPResourceNotAccessible("Project", project_code), + ) + scope_codes = [project_code] + table_group_name = None + else: + scope_codes = list(perms.allowed_codes) + table_group_name = None + + doc = MdDoc() + doc.heading(1, "Quality Scores") + + if table_group_id is not None: + doc.text(f"Scope: Table Group `{table_group_name}` (project `{scope_codes[0]}`).") + elif project_code is not None: + doc.text(f"Scope: Project `{scope_codes[0]}`.") + else: + doc.text(f"Scope: all accessible projects ({len(scope_codes)}).") + + cross_project = project_code is None and table_group_id is None and len(scope_codes) > 1 + + for code in scope_codes: + _render_one_scope( + doc, + project_code=code, + table_group_name=table_group_name, + group_by=parsed_group_by, + score_type=parsed_score_type, + user_filters=user_filters, + include_issue_ct=include_issue_ct, + include_impact=include_impact, + heading=code if cross_project else None, + ) + + return doc.render() + + +def _validate_filters(filters: list[dict] | None) -> list[dict]: + """Validate filter dicts and translate ``field`` from user labels to internal DB columns.""" + if not filters: + return [] + errors: list[str] = [] + cleaned: list[dict] = [] + for i, entry in enumerate(filters): + if not isinstance(entry, dict): + errors.append(f"entry {i}: must be a dict with `field` and `value`") + continue + field = entry.get("field") + value = entry.get("value") + if not field: + errors.append(f"entry {i}: missing `field`") + continue + if value is None or value == "": + errors.append(f"entry {i} ({field!r}): empty value") + continue + try: + parsed_field = parse_score_filter_field(field) + except MCPUserError as err: + errors.append(f"entry {i}: {err}") + continue + if not isinstance(value, str): + errors.append(f"entry {i} ({field!r}): value must be a string") + continue + if len(value) > _VALUE_MAX_LEN: + errors.append( + f"entry {i} ({field!r}): value too long ({len(value)} > {_VALUE_MAX_LEN})" + ) + continue + bad_chars = sorted(set(value) & _VALUE_FORBIDDEN_CHARS) + if bad_chars: + errors.append( + f"entry {i} ({field!r}): value contains forbidden characters {bad_chars}" + ) + continue + cleaned.append({"field": SCORE_FILTER_FIELD_TO_COLUMN[parsed_field], "value": value}) + if errors: + raise MCPUserError("Invalid filters: " + "; ".join(errors)) + return cleaned + + +def _build_definition( + *, + project_code: str, + table_group_name: str | None, + group_by: ScoreGroupBy | None, + score_type: ScoreType | None, + user_filters: list[dict], +) -> ScoreDefinition: + definition = ScoreDefinition() + definition.project_code = project_code + definition.name = "__mcp_get_quality_scores__" + # score_type=None enables both; a specific value enables only that one. + # `as_score_card` derives `cde_only_categories = cde_score and not + # total_score` — so flag combinations decide whether the category SQL + # filters by `critical_data_element = true`. + definition.total_score = score_type is None or score_type is ScoreType.COMBINED + definition.cde_score = score_type is None or score_type is ScoreType.CDE + definition.category = ( + ScoreCategory(SCORE_GROUP_BY_TO_COLUMN[group_by]) if group_by is not None else None + ) + + filters: list[dict] = list(user_filters) + if table_group_name is not None: + filters.append({"field": "table_groups_name", "value": table_group_name}) + elif not filters: + # `as_score_card` short-circuits when criteria has no filters + # (scores.py:292). Mirror the score-explorer UI's pattern: a + # scorecard always carries at least one filter, typically + # `table_groups_name`. For the unfiltered project-wide case, + # enumerate every table group in the project so the criteria + # still narrows by project_code (added by `_get_raw_query_filters`) + # and covers all table groups. + tg_names = [ + tg.table_groups_name + for tg in TableGroup.select_minimal_where( + TableGroup.project_code == project_code, + ) + ] + if tg_names: + filters.extend( + {"field": "table_groups_name", "value": name} for name in tg_names + ) + + definition.criteria = ScoreDefinitionCriteria.from_filters( + filters, group_by_field=True, + ) + return definition + + +def _render_one_scope( + doc: MdDoc, + *, + project_code: str, + table_group_name: str | None, + group_by: ScoreGroupBy | None, + score_type: ScoreType | None, + user_filters: list[dict], + include_issue_ct: bool, + include_impact: bool, + heading: str | None, +) -> None: + if heading is not None: + doc.heading(2, f"Project `{heading}`") + + definition = _build_definition( + project_code=project_code, + table_group_name=table_group_name, + group_by=group_by, + score_type=score_type, + user_filters=user_filters, + ) + + show_combined = score_type is None or score_type is ScoreType.COMBINED + show_cde = score_type is None or score_type is ScoreType.CDE + + card = definition.as_score_card() + if show_combined: + doc.field(_COMBINED_LABEL, friendly_score(card.get("score"))) + if show_cde: + doc.field(_CDE_LABEL, friendly_score(card.get("cde_score"))) + + if include_issue_ct and group_by is None: + doc.field("Issue Count", definition.get_overall_issue_ct()) + + if group_by is None: + return + + group_by_column = SCORE_GROUP_BY_TO_COLUMN[group_by] + + # Per-category data — score, impact, issue_ct — comes from + # get_score_card_breakdown. One call per enabled score type, since each + # filters different rows (Combined includes all data points; CDE filters + # to critical_data_element=true). + combined_rows: dict[str, dict] = {} + cde_rows: dict[str, dict] = {} + if show_combined: + for r in definition.get_score_card_breakdown("score", group_by_column): + label = r.get(group_by_column) + if label is not None: + combined_rows[label] = r + if show_cde: + for r in definition.get_score_card_breakdown("cde_score", group_by_column): + label = r.get(group_by_column) + if label is not None: + cde_rows[label] = r + + all_labels = set(combined_rows) | set(cde_rows) + if not all_labels: + doc.text("_No category data._") + return + + # Worst score first. Sort by primary column (Combined if shown, else CDE). + def _sort_key(label: str) -> float: + primary = combined_rows if show_combined else cde_rows + score = (primary.get(label) or {}).get("score") + return score if score is not None else 1.0 + + sorted_labels = sorted(all_labels, key=_sort_key) + total_rows = len(sorted_labels) + capped = sorted_labels[:_ROW_CAP] + + both_shown = show_combined and show_cde + combined_issue_header = "Issue Count (Combined)" if both_shown else "Issue Count" + cde_issue_header = "Issue Count (CDE)" if both_shown else "Issue Count" + + headers: list[str] = [group_by.value] + if show_combined: + headers.append(_COMBINED_LABEL) + if include_impact: + headers.append("Impact on Combined Score") + if include_issue_ct: + headers.append(combined_issue_header) + if show_cde: + headers.append(_CDE_LABEL) + if include_impact: + headers.append("Impact on CDE Score") + if include_issue_ct: + headers.append(cde_issue_header) + + md_rows: list[list[object]] = [] + for label in capped: + cells: list[object] = [label] + c_row = combined_rows.get(label) or {} + d_row = cde_rows.get(label) or {} + if show_combined: + cells.append(friendly_score(c_row.get("score"))) + if include_impact: + cells.append(_format_impact(c_row.get("impact"))) + if include_issue_ct: + cells.append(c_row.get("issue_ct") if c_row else None) + if show_cde: + cells.append(friendly_score(d_row.get("score"))) + if include_impact: + cells.append(_format_impact(d_row.get("impact"))) + if include_issue_ct: + cells.append(d_row.get("issue_ct") if d_row else None) + md_rows.append(cells) + doc.table(headers, md_rows) + + if total_rows > _ROW_CAP: + doc.text(f"_Showing top {_ROW_CAP} of {total_rows} rows by lowest score._") + + +def _format_impact(value: float | None) -> str | None: + # Pass None through so MdDoc renders an em-dash for missing data + # (friendly_score_impact returns the literal "-" for None/0, which + # mismatches the score column's em-dash treatment). + if value is None: + return None + return friendly_score_impact(value) diff --git a/tests/unit/common/models/test_score_definition.py b/tests/unit/common/models/test_score_definition.py index 3b678e09..b0acb712 100644 --- a/tests/unit/common/models/test_score_definition.py +++ b/tests/unit/common/models/test_score_definition.py @@ -17,9 +17,6 @@ ScoreDefinitionFilter, ) - -from testgen.common.models.scores import ScoreDefinition - pytestmark = pytest.mark.unit @@ -208,3 +205,87 @@ def test_list_with_table_group_targets_empty_project(mock_session_fn): mock_session_fn.return_value.execute.return_value = mock_result assert ScoreDefinition.list_with_table_group_targets("proj") == [] + + +# --- get_overall_issue_ct --- + + +def _definition_with_filter(project_code="demo", field="business_domain", value="Finance"): + """Build a transient ScoreDefinition with one filter.""" + definition = ScoreDefinition() + definition.project_code = project_code + definition.name = "test" + definition.total_score = True + definition.cde_score = False + definition.criteria = ScoreDefinitionCriteria( + operand="AND", + group_by_field=True, + filters=[ScoreDefinitionFilter(field=field, value=value)], + ) + return definition + + +@patch("testgen.common.models.scores.get_current_session") +def test_get_overall_issue_ct_sums_profile_and_test(mock_session_fn): + """Returns the sum of profile + test issue_ct from the two scoring views.""" + definition = _definition_with_filter() + # Two execute() calls; first returns profile sum, second returns test sum. + mock_session_fn.return_value.execute.side_effect = [ + MagicMock(scalar=MagicMock(return_value=7)), + MagicMock(scalar=MagicMock(return_value=3)), + ] + + assert definition.get_overall_issue_ct() == 10 + + +@patch("testgen.common.models.scores.get_current_session") +def test_get_overall_issue_ct_queries_both_views(mock_session_fn): + """Issues two queries — one against the profile view, one against the test view.""" + definition = _definition_with_filter() + mock_session_fn.return_value.execute.side_effect = [ + MagicMock(scalar=MagicMock(return_value=0)), + MagicMock(scalar=MagicMock(return_value=0)), + ] + + definition.get_overall_issue_ct() + + calls = mock_session_fn.return_value.execute.call_args_list + assert len(calls) == 2 + sql_1 = str(calls[0].args[0]) + sql_2 = str(calls[1].args[0]) + assert "v_dq_profile_scoring_latest_by_column" in sql_1 + assert "v_dq_test_scoring_latest_by_column" in sql_2 + # Both queries must use the same filters as as_score_card (project_code + criteria). + for sql in (sql_1, sql_2): + assert "project_code = 'demo'" in sql + assert "business_domain = 'Finance'" in sql + + +@patch("testgen.common.models.scores.get_current_session") +def test_get_overall_issue_ct_handles_null_scalars(mock_session_fn): + """A NULL sum (no matching rows) is treated as 0, not None.""" + definition = _definition_with_filter() + mock_session_fn.return_value.execute.side_effect = [ + MagicMock(scalar=MagicMock(return_value=None)), + MagicMock(scalar=MagicMock(return_value=None)), + ] + + assert definition.get_overall_issue_ct() == 0 + + +def test_get_overall_issue_ct_no_filters_returns_zero(): + """When the definition has no filters, return 0 without hitting the DB.""" + definition = ScoreDefinition() + definition.project_code = "demo" + definition.name = "test" + definition.total_score = True + definition.cde_score = False + definition.criteria = ScoreDefinitionCriteria( + operand="AND", + group_by_field=True, + filters=[], + ) + + with patch("testgen.common.models.scores.get_current_session") as mock_session_fn: + assert definition.get_overall_issue_ct() == 0 + mock_session_fn.return_value.execute.assert_not_called() diff --git a/tests/unit/mcp/test_tools_common.py b/tests/unit/mcp/test_tools_common.py index 193a1c83..9416cc16 100644 --- a/tests/unit/mcp/test_tools_common.py +++ b/tests/unit/mcp/test_tools_common.py @@ -7,6 +7,12 @@ from testgen.common.models.test_result import TestResultStatus from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError from testgen.mcp.tools.common import ( + SCORE_FILTER_FIELD_TO_COLUMN, + SCORE_GROUP_BY_TO_COLUMN, + SCORE_TYPE_TO_INTERNAL, + ScoreFilterField, + ScoreGroupBy, + ScoreType, format_disposition, parse_disposition, parse_impact_dimension, @@ -14,6 +20,9 @@ parse_pii_risk_list, parse_quality_dimension, parse_result_status, + parse_score_filter_field, + parse_score_group_by, + parse_score_type, parse_uuid, resolve_issue_type, resolve_profiling_run, @@ -464,3 +473,125 @@ def test_build_ilike_pattern_escapes_underscores_even_with_explicit_percent(): from testgen.mcp.tools.common import build_ilike_pattern # The `_` escape is unconditional — explicit `%` doesn't suppress it. assert build_ilike_pattern("user_%") == r"user\_%" + + +# --- parse_score_group_by --- + + +@pytest.mark.parametrize("member", list(ScoreGroupBy)) +def test_parse_score_group_by_user_labels(member): + assert parse_score_group_by(member.value) is member + + +def test_parse_score_group_by_label_maps_to_internal_column(): + """The enum value is the user-facing label; the mapping translates to the + internal DB column name used downstream (``ScoreCategory``, the criteria + filter list).""" + assert SCORE_GROUP_BY_TO_COLUMN[ScoreGroupBy.QUALITY_DIMENSION] == "dq_dimension" + assert SCORE_GROUP_BY_TO_COLUMN[ScoreGroupBy.TABLE_GROUP] == "table_groups_name" + assert SCORE_GROUP_BY_TO_COLUMN[ScoreGroupBy.BUSINESS_DOMAIN] == "business_domain" + + +@pytest.mark.parametrize( + "internal", + ["dq_dimension", "impact_dimension", "business_domain", "table_groups_name"], +) +def test_parse_score_group_by_rejects_internal_column_name(internal): + """Old internal vocabulary must be rejected — the tool now speaks user labels only.""" + with pytest.raises(MCPUserError, match="Invalid group_by") as exc_info: + parse_score_group_by(internal) + msg = str(exc_info.value) + # Error must point users at the new user-facing vocabulary. + assert "Quality Dimension" in msg + assert "Business Domain" in msg + + +def test_parse_score_group_by_invalid_lists_valid_values(): + with pytest.raises(MCPUserError, match="Valid values:") as exc_info: + parse_score_group_by("Made Up") + msg = str(exc_info.value) + for member in ScoreGroupBy: + assert member.value in msg + + +# --- parse_score_filter_field --- + + +@pytest.mark.parametrize("member", list(ScoreFilterField)) +def test_parse_score_filter_field_user_labels(member): + assert parse_score_filter_field(member.value) is member + + +def test_parse_score_filter_field_label_maps_to_internal_column(): + assert SCORE_FILTER_FIELD_TO_COLUMN[ScoreFilterField.BUSINESS_DOMAIN] == "business_domain" + assert SCORE_FILTER_FIELD_TO_COLUMN[ScoreFilterField.TABLE_GROUP] == "table_groups_name" + + +def test_parse_score_filter_field_does_not_include_dimensions(): + """Quality Dimension / Impact Dimension are valid only as group_by, not as filter fields.""" + values = {m.value for m in ScoreFilterField} + assert "Quality Dimension" not in values + assert "Impact Dimension" not in values + + +@pytest.mark.parametrize("label", ["Quality Dimension", "Impact Dimension"]) +def test_parse_score_filter_field_rejects_dimension_with_hint(label): + """Passing a dimension as filter.field hints at group_by= usage instead.""" + with pytest.raises(MCPUserError, match=f"`{label}`") as exc_info: + parse_score_filter_field(label) + msg = str(exc_info.value) + assert "group_by" in msg + assert label in msg + + +@pytest.mark.parametrize( + "internal", ["business_domain", "data_source", "table_groups_name"], +) +def test_parse_score_filter_field_rejects_internal_column_name(internal): + with pytest.raises(MCPUserError, match="Invalid filter field") as exc_info: + parse_score_filter_field(internal) + msg = str(exc_info.value) + assert "Business Domain" in msg + + +def test_parse_score_filter_field_invalid_lists_valid_values(): + with pytest.raises(MCPUserError, match="Valid values:") as exc_info: + parse_score_filter_field("Made Up") + msg = str(exc_info.value) + for member in ScoreFilterField: + assert member.value in msg + + +# --- parse_score_type --- + + +@pytest.mark.parametrize( + "label,expected_member,expected_internal", + [ + ("Combined", ScoreType.COMBINED, "total"), + ("CDE", ScoreType.CDE, "cde"), + ], +) +def test_parse_score_type_user_labels(label, expected_member, expected_internal): + member = parse_score_type(label) + assert member is expected_member + assert SCORE_TYPE_TO_INTERNAL[member] == expected_internal + + +@pytest.mark.parametrize("internal", ["total", "cde", "combined"]) +def test_parse_score_type_rejects_internal_or_wrong_case(internal): + """The old internal vocabulary (``total``/``cde`` lowercase) must no longer + be accepted on input.""" + with pytest.raises(MCPUserError, match="Invalid score_type") as exc_info: + parse_score_type(internal) + msg = str(exc_info.value) + assert "Combined" in msg + assert "CDE" in msg + + +def test_parse_score_type_invalid_lists_valid_values(): + with pytest.raises(MCPUserError, match="Valid values:") as exc_info: + parse_score_type("BadType") + msg = str(exc_info.value) + for member in ScoreType: + assert member.value in msg diff --git a/tests/unit/mcp/test_tools_quality_scores.py b/tests/unit/mcp/test_tools_quality_scores.py new file mode 100644 index 00000000..10aae95c --- /dev/null +++ b/tests/unit/mcp/test_tools_quality_scores.py @@ -0,0 +1,742 @@ +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest + +from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError +from testgen.mcp.permissions import ProjectPermissions + +pytestmark = pytest.mark.unit + + +# --- Helpers --- + + +def _score_card( + score=0.9, + cde_score=0.8, + profiling_score=0.95, + testing_score=0.85, + categories=None, +): + """Default ScoreCard dict returned by ScoreDefinition.as_score_card().""" + return { + "id": uuid4(), + "project_code": "demo", + "name": "test", + "score": score, + "cde_score": cde_score, + "profiling_score": profiling_score, + "testing_score": testing_score, + "categories": categories or [], + "history": [], + "definition": None, + } + + +def _patch_perms(allowed=("demo",), memberships=None): + """Return a patch context manager that injects a ProjectPermissions with given access.""" + memberships = memberships or {code: "role_a" for code in allowed} + return patch( + "testgen.mcp.permissions._compute_project_permissions", + return_value=ProjectPermissions( + memberships=memberships, permission="view", username="test_user", + ), + ) + + +# --- Argument validation --- + + +def test_mutually_exclusive_scope_args_rejected(db_session_mock): + from testgen.mcp.tools.quality_scores import get_quality_scores + + with _patch_perms(), pytest.raises(MCPUserError, match="project_code.*table_group_id"): + get_quality_scores(project_code="demo", table_group_id=str(uuid4())) + + +def test_invalid_group_by_rejected(db_session_mock): + from testgen.mcp.tools.quality_scores import get_quality_scores + + with _patch_perms(), pytest.raises(MCPUserError, match="Invalid group_by") as exc_info: + get_quality_scores(project_code="demo", group_by="invented_field") + msg = str(exc_info.value) + # Error message must speak the user-facing vocabulary. + assert "Quality Dimension" in msg + + +@pytest.mark.parametrize("group_by", ["column_name", "table_name", "dq_dimension"]) +def test_internal_group_by_value_rejected(group_by, db_session_mock): + """Old internal column names (row-level or column-form) are no longer accepted.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + with _patch_perms(), pytest.raises(MCPUserError, match="Invalid group_by"): + get_quality_scores(project_code="demo", group_by=group_by) + + +def test_invalid_score_type_rejected(db_session_mock): + from testgen.mcp.tools.quality_scores import get_quality_scores + + with _patch_perms(), pytest.raises(MCPUserError, match="Invalid score_type") as exc_info: + get_quality_scores(project_code="demo", score_type="garbage") + msg = str(exc_info.value) + assert "Combined" in msg + assert "CDE" in msg + + +@pytest.mark.parametrize("internal", ["total", "cde"]) +def test_internal_score_type_rejected(internal, db_session_mock): + """``total``/``cde`` were the old internal codes — inputs now use ``Combined``/``CDE``.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + with _patch_perms(), pytest.raises(MCPUserError, match="Invalid score_type"): + get_quality_scores(project_code="demo", score_type=internal) + + +def test_invalid_filter_field_rejected(db_session_mock): + from testgen.mcp.tools.quality_scores import get_quality_scores + + with _patch_perms(), pytest.raises(MCPUserError, match="Invalid filter field"): + get_quality_scores( + project_code="demo", + filters=[{"field": "not_a_field", "value": "x"}], + ) + + +def test_internal_filter_field_rejected(db_session_mock): + """Old internal column name as filter field is no longer accepted.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + with _patch_perms(), pytest.raises(MCPUserError, match="Invalid filter field"): + get_quality_scores( + project_code="demo", + filters=[{"field": "business_domain", "value": "Finance"}], + ) + + +def test_quality_dimension_rejected_as_filter_field(db_session_mock): + """Quality Dimension is a group_by, not a filter field — must reject with a hint.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + with _patch_perms(), pytest.raises(MCPUserError, match="Quality Dimension") as exc_info: + get_quality_scores( + project_code="demo", + filters=[{"field": "Quality Dimension", "value": "Accuracy"}], + ) + assert "group_by" in str(exc_info.value) + + +def test_impact_dimension_rejected_as_filter_field(db_session_mock): + from testgen.mcp.tools.quality_scores import get_quality_scores + + with _patch_perms(), pytest.raises(MCPUserError, match="Impact Dimension") as exc_info: + get_quality_scores( + project_code="demo", + filters=[{"field": "Impact Dimension", "value": "Workflow"}], + ) + assert "group_by" in str(exc_info.value) + + +def test_filter_value_with_forbidden_chars_rejected(db_session_mock): + """SQL-injection probe — values with single quotes or semicolons must be rejected.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + with _patch_perms(), pytest.raises(MCPUserError, match="forbidden"): + get_quality_scores( + project_code="demo", + filters=[{"field": "Business Domain", "value": "O';DROP TABLE"}], + ) + + +def test_filter_value_oversize_rejected(db_session_mock): + from testgen.mcp.tools.quality_scores import get_quality_scores + + with _patch_perms(), pytest.raises(MCPUserError, match="too long"): + get_quality_scores( + project_code="demo", + filters=[{"field": "Business Domain", "value": "x" * 257}], + ) + + +def test_multiple_filter_problems_listed_at_once(db_session_mock): + """When several filter entries are bad, the error lists every offender.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + bad_filters = [ + {"field": "Quality Dimension", "value": "Accuracy"}, # not a filter field + {"field": "Business Domain", "value": "x';--"}, # bad chars + {"field": "Data Source", "value": ""}, # empty value + ] + with _patch_perms(), pytest.raises(MCPUserError) as exc_info: + get_quality_scores(project_code="demo", filters=bad_filters) + + msg = str(exc_info.value) + assert "Quality Dimension" in msg + assert "Business Domain" in msg + assert "Data Source" in msg + + +def test_project_not_accessible_rejected(db_session_mock): + """A project the user can't view raises MCPResourceNotAccessible-style error.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + with _patch_perms(allowed=("only_this",)), pytest.raises(MCPResourceNotAccessible, match="forbidden_proj"): + get_quality_scores(project_code="forbidden_proj") + + +# --- Score-type → model-call mapping --- + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_default_overall_shows_both_combined_and_cde(mock_definition_cls, db_session_mock): + """score_type omitted → both Combined and CDE Score lines are rendered.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.93, cde_score=0.81) + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + filters=[{"field": "Business Domain", "value": "Finance"}], + ) + + assert "Combined Score" in out + assert "93" in out + assert "CDE Score" in out + assert "81" in out + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_combined_overall_shows_only_combined(mock_definition_cls, db_session_mock): + """score_type='Combined' renders only the Combined Score line.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.93, cde_score=None) + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + score_type="Combined", + filters=[{"field": "Business Domain", "value": "Finance"}], + ) + + assert "Combined Score" in out + assert "93" in out + assert "CDE Score" not in out + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_cde_overall_shows_only_cde(mock_definition_cls, db_session_mock): + """score_type='CDE' renders only the CDE Score line.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=None, cde_score=0.81) + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + score_type="CDE", + filters=[{"field": "Business Domain", "value": "Finance"}], + ) + + assert "CDE Score" in out + assert "81" in out + assert "Combined Score" not in out + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_combined_grouped_uses_breakdown(mock_definition_cls, db_session_mock): + """score_type='Combined' + group_by sources per-category rows from breakdown. + + Per-category output always includes Impact (matching the Score Explorer UI), + so the tool reads from get_score_card_breakdown rather than card.categories. + """ + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9) + mock_definition.get_score_card_breakdown.return_value = [ + {"business_domain": "Finance", "score": 0.91, "issue_ct": 4, "impact": 3.2}, + {"business_domain": "Marketing", "score": 0.74, "issue_ct": 11, "impact": 9.8}, + ] + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + score_type="Combined", + group_by="Business Domain", + filters=[{"field": "Data Source", "value": "warehouse"}], + include_impact=True, + ) + + mock_definition.get_score_card_breakdown.assert_called_once_with("score", "business_domain") + assert "Finance" in out + assert "Marketing" in out + assert "Impact on Combined Score" in out + assert "Impact on CDE Score" not in out + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_cde_grouped_uses_breakdown(mock_definition_cls, db_session_mock): + """score_type='CDE' + group_by sources per-category rows from breakdown.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=None, cde_score=0.72) + mock_definition.get_score_card_breakdown.return_value = [ + {"business_domain": "Finance", "score": 0.80, "issue_ct": 2, "impact": 1.5}, + ] + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + score_type="CDE", + group_by="Business Domain", + filters=[{"field": "Data Source", "value": "warehouse"}], + include_impact=True, + ) + + mock_definition.get_score_card_breakdown.assert_called_once_with("cde_score", "business_domain") + assert "Finance" in out + assert "Impact on CDE Score" in out + assert "Impact on Combined Score" not in out + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_default_grouped_renders_both_score_columns(mock_definition_cls, db_session_mock): + """score_type omitted + group_by → table has Combined + CDE columns and + Impact columns for both, populated from two breakdown calls. + """ + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9, cde_score=0.7) + + breakdown_results = { + "score": [ + {"business_domain": "Finance", "score": 0.91, "issue_ct": 4, "impact": 3.2}, + {"business_domain": "Marketing", "score": 0.74, "issue_ct": 12, "impact": 11.4}, + ], + "cde_score": [ + {"business_domain": "Finance", "score": 0.85, "issue_ct": 3, "impact": 5.0}, + {"business_domain": "Marketing", "score": 0.60, "issue_ct": 8, "impact": 12.0}, + ], + } + mock_definition.get_score_card_breakdown.side_effect = ( + lambda score_key, _col: breakdown_results[score_key] + ) + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + group_by="Business Domain", + filters=[{"field": "Data Source", "value": "warehouse"}], + include_impact=True, + ) + + # Both score types → two breakdown calls + assert mock_definition.get_score_card_breakdown.call_count == 2 + call_keys = {c.args[0] for c in mock_definition.get_score_card_breakdown.call_args_list} + assert call_keys == {"score", "cde_score"} + + assert "Combined Score" in out + assert "CDE Score" in out + assert "Impact on Combined Score" in out + assert "Impact on CDE Score" in out + assert "Finance" in out + assert "Marketing" in out + + +# --- include_impact --- + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_include_impact_default_false_omits_impact_columns(mock_definition_cls, db_session_mock): + """Default include_impact=False → grouped output has no Impact columns.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9, cde_score=0.7) + breakdown_results = { + "score": [{"business_domain": "Finance", "score": 0.91, "issue_ct": 4, "impact": 3.2}], + "cde_score": [{"business_domain": "Finance", "score": 0.85, "issue_ct": 3, "impact": 5.0}], + } + mock_definition.get_score_card_breakdown.side_effect = ( + lambda score_key, _col: breakdown_results[score_key] + ) + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + group_by="Business Domain", + filters=[{"field": "Data Source", "value": "wh"}], + ) + + assert "Finance" in out + assert "Combined Score" in out + assert "CDE Score" in out + assert "Impact" not in out + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_include_impact_false_combined_only_omits_impact_column(mock_definition_cls, db_session_mock): + """Combined-only + default include_impact=False → no impact column.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9) + mock_definition.get_score_card_breakdown.return_value = [ + {"business_domain": "Finance", "score": 0.91, "issue_ct": 4, "impact": 3.2}, + ] + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + score_type="Combined", + group_by="Business Domain", + filters=[{"field": "Data Source", "value": "wh"}], + ) + + assert "Finance" in out + assert "Combined Score" in out + assert "Impact" not in out + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_include_impact_false_cde_only_omits_impact_column(mock_definition_cls, db_session_mock): + """CDE-only + default include_impact=False → no impact column.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=None, cde_score=0.7) + mock_definition.get_score_card_breakdown.return_value = [ + {"business_domain": "Finance", "score": 0.8, "issue_ct": 3, "impact": 2.0}, + ] + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + score_type="CDE", + group_by="Business Domain", + filters=[{"field": "Data Source", "value": "wh"}], + ) + + assert "Finance" in out + assert "CDE Score" in out + assert "Impact" not in out + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_include_impact_false_overall_unaffected(mock_definition_cls, db_session_mock): + """include_impact only affects grouped output — overall block is unchanged.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.93, cde_score=0.81) + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out_default = get_quality_scores( + project_code="demo", + filters=[{"field": "Business Domain", "value": "Finance"}], + ) + out_with_impact = get_quality_scores( + project_code="demo", + include_impact=True, + filters=[{"field": "Business Domain", "value": "Finance"}], + ) + + # No group_by → impact has no rendering surface either way. + assert "Impact" not in out_default + assert "Impact" not in out_with_impact + + +# --- include_issue_ct --- + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_include_issue_ct_overall_calls_get_overall_issue_ct(mock_definition_cls, db_session_mock): + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9) + mock_definition.get_overall_issue_ct.return_value = 42 + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + include_issue_ct=True, + filters=[{"field": "Business Domain", "value": "Finance"}], + ) + + mock_definition.get_overall_issue_ct.assert_called_once_with() + assert "Issue Count" in out + assert "42" in out + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_include_issue_ct_grouped_combined_uses_simple_label(mock_definition_cls, db_session_mock): + """grouped + Combined + include_issue_ct: single 'Issue Count' column header.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9) + mock_definition.get_score_card_breakdown.return_value = [ + {"business_domain": "Finance", "score": 0.91, "issue_ct": 7, "impact": 4.0}, + ] + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + score_type="Combined", + group_by="Business Domain", + include_issue_ct=True, + filters=[{"field": "Data Source", "value": "wh"}], + ) + + mock_definition.get_score_card_breakdown.assert_called_once_with("score", "business_domain") + assert "Finance" in out + assert "7" in out + assert "Issue Count" in out + assert "Issue Count (Combined)" not in out + assert "Issue Count (CDE)" not in out + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_include_issue_ct_grouped_cde_uses_simple_label(mock_definition_cls, db_session_mock): + """grouped + CDE + include_issue_ct: single 'Issue Count' column header.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=None, cde_score=0.7) + mock_definition.get_score_card_breakdown.return_value = [ + {"business_domain": "Finance", "score": 0.8, "issue_ct": 3, "impact": 2.0}, + ] + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + score_type="CDE", + group_by="Business Domain", + include_issue_ct=True, + filters=[{"field": "Data Source", "value": "wh"}], + ) + + mock_definition.get_score_card_breakdown.assert_called_once_with("cde_score", "business_domain") + assert "Finance" in out + assert "3" in out + assert "Issue Count" in out + assert "Issue Count (Combined)" not in out + assert "Issue Count (CDE)" not in out + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_include_issue_ct_grouped_default_uses_parenthetical_labels(mock_definition_cls, db_session_mock): + """grouped + score_type unset + include_issue_ct: separate Combined / CDE + issue-count columns, and both Impact columns.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9, cde_score=0.7) + breakdown_results = { + "score": [{"business_domain": "Finance", "score": 0.91, "issue_ct": 7, "impact": 4.0}], + "cde_score": [{"business_domain": "Finance", "score": 0.80, "issue_ct": 3, "impact": 2.0}], + } + mock_definition.get_score_card_breakdown.side_effect = ( + lambda score_key, _col: breakdown_results[score_key] + ) + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + group_by="Business Domain", + include_issue_ct=True, + include_impact=True, + filters=[{"field": "Data Source", "value": "wh"}], + ) + + assert mock_definition.get_score_card_breakdown.call_count == 2 + assert "Issue Count (Combined)" in out + assert "Issue Count (CDE)" in out + assert "Impact on Combined Score" in out + assert "Impact on CDE Score" in out + # Both per-category issue counts must appear, not just one + assert "7" in out # combined count + assert "3" in out # cde count + + +# --- Filter semantics passed to the model --- + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinitionCriteria") +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_filters_passed_to_from_filters(mock_definition_cls, mock_criteria_cls, db_session_mock): + """User filters are handed to ScoreDefinitionCriteria.from_filters.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9) + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + get_quality_scores( + project_code="demo", + filters=[ + {"field": "Business Domain", "value": "Finance"}, + {"field": "Business Domain", "value": "Marketing"}, + {"field": "Data Source", "value": "warehouse"}, + ], + ) + + # from_filters receives the translated DB column names — the parser + # converts user-facing labels to internal column names before this call. + mock_criteria_cls.from_filters.assert_called_once() + args, kwargs = mock_criteria_cls.from_filters.call_args + passed = args[0] + assert {"field": "business_domain", "value": "Finance"} in passed + assert {"field": "business_domain", "value": "Marketing"} in passed + assert {"field": "data_source", "value": "warehouse"} in passed + assert kwargs.get("group_by_field") is True + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinitionCriteria") +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +@patch("testgen.mcp.tools.common.TableGroup") +def test_table_group_adds_implicit_name_filter( + mock_tg_cls, mock_definition_cls, mock_criteria_cls, db_session_mock, +): + """When table_group_id is passed, the resolved TG's name is added as a filter.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + tg = MagicMock() + tg.id = uuid4() + tg.project_code = "demo" + tg.table_groups_name = "orders" + mock_tg_cls.get.return_value = tg + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9) + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + get_quality_scores(table_group_id=str(tg.id)) + + args, _ = mock_criteria_cls.from_filters.call_args + passed = args[0] + assert {"field": "table_groups_name", "value": "orders"} in passed + + +# --- Cross-project loop --- + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_cross_project_renders_per_project_sections(mock_definition_cls, db_session_mock): + """No project_code, no table_group_id → one H2 section per accessible project.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9) + mock_definition_cls.return_value = mock_definition + + # Pass at least one filter so the tool doesn't fall into the + # "enumerate every table group in the project" branch (which would need + # `TableGroup.select_minimal_where` mocked). + with _patch_perms(allowed=("proj_a", "proj_b")): + out = get_quality_scores( + filters=[{"field": "Business Domain", "value": "Finance"}], + ) + + assert "proj_a" in out + assert "proj_b" in out + # `as_score_card` should have been called once per project. + assert mock_definition.as_score_card.call_count == 2 + + +@patch("testgen.mcp.tools.quality_scores.TableGroup") +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_unfiltered_project_enumerates_table_groups(mock_definition_cls, mock_tg_cls, db_session_mock): + """Unfiltered project_code call enumerates table groups so as_score_card's + has_filters() gate passes (mirrors the score-explorer UI default).""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + tg1 = MagicMock() + tg1.table_groups_name = "orders" + tg2 = MagicMock() + tg2.table_groups_name = "customers" + mock_tg_cls.select_minimal_where.return_value = [tg1, tg2] + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9) + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + get_quality_scores(project_code="demo") + + # Verify TableGroup.select_minimal_where was called for enumeration. + mock_tg_cls.select_minimal_where.assert_called_once() + + +# --- Row cap --- + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_grouped_row_cap_truncates_and_footers(mock_definition_cls, db_session_mock): + """At >_ROW_CAP category rows, render only top N and surface the cap in a footer.""" + from testgen.mcp.tools.quality_scores import _ROW_CAP, get_quality_scores + + breakdown_rows = [ + {"business_domain": f"L{i}", "score": 0.5 + i * 0.001, "issue_ct": 1, "impact": 0.1} + for i in range(_ROW_CAP + 10) + ] + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9) + mock_definition.get_score_card_breakdown.return_value = breakdown_rows + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + score_type="Combined", + group_by="Business Domain", + filters=[{"field": "Data Source", "value": "wh"}], + ) + + assert f"Showing top {_ROW_CAP}" in out + assert str(_ROW_CAP + 10) in out + + +# --- Transient definition is never persisted --- + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_transient_definition_never_persisted(mock_definition_cls, db_session_mock): + """Hardening test: the MCP tool never calls .save() on its transient definition.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9) + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + get_quality_scores( + project_code="demo", + filters=[{"field": "Business Domain", "value": "Finance"}], + ) + + mock_definition.save.assert_not_called() From 3818ff057ca3b524dfa3cd1b5a29736313cdeeb9 Mon Sep 17 00:00:00 2001 From: Luis Date: Fri, 15 May 2026 13:37:15 -0400 Subject: [PATCH 35/58] feat(mcp): add CRUD tools for quality scores - list_scorecards - get_scorecard - create_scorecard - delete_scorecard - update_scorecard --- .../run_refresh_score_cards_results.py | 47 + testgen/common/models/scores.py | 46 +- testgen/mcp/server.py | 14 +- testgen/mcp/services/inventory_service.py | 4 +- testgen/mcp/tools/common.py | 86 +- testgen/mcp/tools/quality_scores.py | 874 ++++++- testgen/ui/scripts/patch_streamlit.py | 1 - testgen/ui/views/score_explorer.py | 27 +- tests/unit/commands/test_score_cards.py | 133 +- .../common/models/test_score_definition.py | 143 ++ tests/unit/mcp/test_tools_common.py | 102 +- tests/unit/mcp/test_tools_quality_scores.py | 2256 ++++++++++++++++- 12 files changed, 3476 insertions(+), 257 deletions(-) diff --git a/testgen/commands/run_refresh_score_cards_results.py b/testgen/commands/run_refresh_score_cards_results.py index 3a6a71f3..ee8800f8 100644 --- a/testgen/commands/run_refresh_score_cards_results.py +++ b/testgen/commands/run_refresh_score_cards_results.py @@ -3,6 +3,7 @@ import time from testgen.common.models import get_current_session, with_database_session +from testgen.common.models.profiling_run import ProfilingRun from testgen.common.models.scores import ( SCORE_CATEGORIES, ScoreCard, @@ -11,6 +12,7 @@ ScoreDefinitionResult, ScoreDefinitionResultHistoryEntry, ) +from testgen.common.models.test_run import TestRun from testgen.common.notifications.score_drop import collect_score_notification_data, send_score_drop_notifications LOG = logging.getLogger("testgen") @@ -169,3 +171,48 @@ def run_recalculate_score_card(*, project_code: str, definition_id: str): project_code, round(end_time - start_time, 2), ) + + +@with_database_session +def save_and_refresh_score_definition( + score_definition: ScoreDefinition, + *, + is_new: bool, +) -> ScoreDefinition: + """Save a scorecard and refresh / recalculate its cached scores. + + Owns the persist-then-refresh orchestration shared by the Score Explorer UI + and the ``update_scorecard`` MCP tool. UI-only concerns (Streamlit cache + clear, navigation, toasts) stay in the view layer. + + For new scorecards (``is_new=True``), seeds the first refresh with a + history entry timestamped at the latest profiling or test run for the + project, so the trend chart has an anchor point. For existing scorecards, + also runs ``run_recalculate_score_card`` to update history entries whose + scores might have shifted under the new filters. + """ + refresh_kwargs: dict = {} + if is_new: + # tz-aware sentinel: run_time is stored as TIMESTAMP(timezone=True), so a naive + # min would raise on comparison when only one of the two runs exists. + epoch = datetime.datetime.min.replace(tzinfo=datetime.UTC) + latest_run = max( + ( + ProfilingRun.get_latest_run(score_definition.project_code), + TestRun.get_latest_run(score_definition.project_code), + ), + key=lambda run: getattr(run, "run_time", epoch), + ) + refresh_kwargs = { + "add_history_entry": True, + "refresh_date": latest_run.run_time if latest_run else None, + } + + score_definition.save() + run_refresh_score_cards_results(definition_id=score_definition.id, **refresh_kwargs) + if not is_new: + run_recalculate_score_card( + project_code=score_definition.project_code, + definition_id=score_definition.id, + ) + return score_definition diff --git a/testgen/common/models/scores.py b/testgen/common/models/scores.py index 803d7f7d..596361c3 100644 --- a/testgen/common/models/scores.py +++ b/testgen/common/models/scores.py @@ -29,7 +29,7 @@ text, ) from sqlalchemy.dialects import postgresql -from sqlalchemy.orm import aliased, attributes, relationship +from sqlalchemy.orm import aliased, attributes, joinedload, relationship from testgen.common import read_template_sql_file from testgen.common.models import Base, get_current_session @@ -212,7 +212,14 @@ def list_with_table_group_targets( .order_by(ScoreDefinition.name) ) rows = get_current_session().execute(query).all() - return [(row.id, row.name, list(row.tg_names) if row.tg_names else []) for row in rows] + # Dedupe tg_names: a mode-2 scorecard with N chains under the same + # table_groups_name would otherwise list the name N times, causing the + # inventory tool to render the scorecard once per chain. dict.fromkeys + # preserves first-seen order. + return [ + (row.id, row.name, list(dict.fromkeys(row.tg_names)) if row.tg_names else []) + for row in rows + ] @classmethod def all( @@ -266,6 +273,40 @@ def all( return definitions + @classmethod + def list_for_project( + cls, + project_code: str, + page: int = 1, + limit: int = 20, + ) -> tuple[list[Self], int]: + """Paginated list of scorecards in a project. + + Returns ORM objects with ``criteria`` eager-loaded so callers can walk + the filter chain without firing extra queries. ``results`` is already + ``lazy="joined"`` and rides along automatically — feeds + ``as_cached_score_card()``. + """ + session = get_current_session() + base_filter = ScoreDefinition.project_code == project_code + + total = session.scalar( + select(func.count()).select_from( + select(ScoreDefinition.id).where(base_filter).subquery() + ) + ) or 0 + + query = ( + select(ScoreDefinition) + .options(joinedload(ScoreDefinition.criteria)) + .where(base_filter) + .order_by(ScoreDefinition.name) + .offset((page - 1) * limit) + .limit(limit) + ) + rows = session.scalars(query).unique().all() + return list(rows), total + def save(self) -> None: db_session = get_current_session() db_session.add(self) @@ -828,7 +869,6 @@ def add_as_cutoff(self): Query templates: add_latest_runs.sql """ - # ruff: noqa: RUF027 query = read_template_sql_file("add_latest_runs.sql", sub_directory="score_cards") params = { "project_code": self.definition.project_code, diff --git a/testgen/mcp/server.py b/testgen/mcp/server.py index 1e77724d..36f71445 100644 --- a/testgen/mcp/server.py +++ b/testgen/mcp/server.py @@ -167,7 +167,14 @@ def build_mcp_server( list_profiling_summaries, search_columns, ) - from testgen.mcp.tools.quality_scores import get_quality_scores + from testgen.mcp.tools.quality_scores import ( + create_scorecard, + delete_scorecard, + get_quality_scores, + get_scorecard, + list_scorecards, + update_scorecard, + ) from testgen.mcp.tools.reference import ( column_profile_fields_resource, get_test_type, @@ -280,6 +287,11 @@ def safe_prompt(fn): safe_tool(update_schedule) safe_tool(delete_schedule) safe_tool(get_quality_scores) + safe_tool(list_scorecards) + safe_tool(get_scorecard) + safe_tool(create_scorecard) + safe_tool(update_scorecard) + safe_tool(delete_scorecard) # Resources safe_resource("testgen://test-types", test_types_resource) diff --git a/testgen/mcp/services/inventory_service.py b/testgen/mcp/services/inventory_service.py index d744c89c..cee1c7a6 100644 --- a/testgen/mcp/services/inventory_service.py +++ b/testgen/mcp/services/inventory_service.py @@ -221,13 +221,13 @@ def _profiling_summary_fragment(summary: TableGroupSummary) -> str: + (summary.latest_hygiene_issues_likely_ct or 0) + (summary.latest_hygiene_issues_possible_ct or 0) ) - combined = friendly_score(score(summary.dq_score_profiling, summary.dq_score_testing)) + total = friendly_score(score(summary.dq_score_profiling, summary.dq_score_testing)) profiled_at = ( summary.latest_profile_start.strftime("%Y-%m-%d") if summary.latest_profile_start else "—" ) return ( - f"Score {combined}, hygiene issues {hygiene_issue_total}, " + f"Score {total}, hygiene issues {hygiene_issue_total}, " f"last profiled {profiled_at}, " f"profiling run `{summary.latest_profile_job_execution_id}`" ) diff --git a/testgen/mcp/tools/common.py b/testgen/mcp/tools/common.py index 7e5ed394..1da05a6b 100644 --- a/testgen/mcp/tools/common.py +++ b/testgen/mcp/tools/common.py @@ -17,6 +17,7 @@ from testgen.common.models.hygiene_issue import HygieneIssueType from testgen.common.models.profiling_run import ProfilingRun from testgen.common.models.scheduler import SCHEDULABLE_JOB_KEYS, JobSchedule +from testgen.common.models.scores import ScoreCategory, ScoreDefinition from testgen.common.models.table_group import TableGroup from testgen.common.models.test_definition import TestDefinition, TestType from testgen.common.models.test_result import TestResultStatus @@ -47,6 +48,7 @@ class DocGroup(StrEnum): INVESTIGATE = "Investigate quality issues" BROWSE_PROFILING = "Browse profiling results" TRIGGER = "Trigger profiling, tests, and test generation" + SCORING = "Track data quality scores" def parse_uuid(value: str, label: str = "ID") -> UUID: @@ -167,21 +169,63 @@ class ScoreFilterField(StrEnum): } -class ScoreType(StrEnum): - """User-facing values accepted for the ``score_type`` argument.""" +class ScoreCategoryArg(StrEnum): + """User-facing values accepted for the ``category`` argument on scorecard CRUD. - COMBINED = "Combined" - CDE = "CDE" + Same shape as ``ScoreGroupBy`` — every group-by value is also a valid + breakdown category. Kept as a separate enum (rather than reusing + ``ScoreGroupBy``) so each argument has its own valid-value set per the + per-arg enum convention. + """ + TABLE_GROUP = "Table Group" + DATA_LOCATION = "Data Location" + DATA_SOURCE = "Data Source" + SOURCE_SYSTEM = "Source System" + SOURCE_PROCESS = "Source Process" + BUSINESS_DOMAIN = "Business Domain" + STAKEHOLDER_GROUP = "Stakeholder Group" + TRANSFORM_LEVEL = "Transform Level" + QUALITY_DIMENSION = "Quality Dimension" + IMPACT_DIMENSION = "Impact Dimension" + DATA_PRODUCT = "Data Product" -# Translates to the internal sentinel consumed by ``ScoreDefinition.total_score`` -# / ``cde_score`` flag logic. -SCORE_TYPE_TO_INTERNAL: dict[ScoreType, str] = { - ScoreType.COMBINED: "total", - ScoreType.CDE: "cde", + +SCORE_CATEGORY_ARG_TO_COLUMN: dict[ScoreCategoryArg, str] = { + ScoreCategoryArg.TABLE_GROUP: "table_groups_name", + ScoreCategoryArg.DATA_LOCATION: "data_location", + ScoreCategoryArg.DATA_SOURCE: "data_source", + ScoreCategoryArg.SOURCE_SYSTEM: "source_system", + ScoreCategoryArg.SOURCE_PROCESS: "source_process", + ScoreCategoryArg.BUSINESS_DOMAIN: "business_domain", + ScoreCategoryArg.STAKEHOLDER_GROUP: "stakeholder_group", + ScoreCategoryArg.TRANSFORM_LEVEL: "transform_level", + ScoreCategoryArg.QUALITY_DIMENSION: "dq_dimension", + ScoreCategoryArg.IMPACT_DIMENSION: "impact_dimension", + ScoreCategoryArg.DATA_PRODUCT: "data_product", } +class ScoreChainLeafField(StrEnum): + """User-facing values accepted as the leaf ``field`` in a scorecard filter chain.""" + + TABLE = "Table" + COLUMN = "Column" + + +SCORE_CHAIN_LEAF_TO_COLUMN: dict[ScoreChainLeafField, str] = { + ScoreChainLeafField.TABLE: "table_name", + ScoreChainLeafField.COLUMN: "column_name", +} + + +class ScoreType(StrEnum): + """User-facing values accepted for the ``score_type`` argument.""" + + TOTAL = "Total" + CDE = "CDE" + + def parse_score_group_by(value: str) -> ScoreGroupBy: try: return ScoreGroupBy(value) @@ -210,6 +254,20 @@ def parse_score_type(value: str) -> ScoreType: raise MCPUserError(f"Invalid score_type `{value}`. Valid values: {valid}") from err +def parse_category(value: str) -> ScoreCategory: + """Validate a ``category`` argument and return the stored ``ScoreCategory``. + + Accepts the display-form values exposed by ``get_quality_scores``'s + ``group_by`` argument (e.g. ``Quality Dimension``, ``Data Source``). + """ + try: + arg = ScoreCategoryArg(value) + except ValueError as err: + valid = ", ".join(c.value for c in ScoreCategoryArg) + raise MCPUserError(f"Invalid category `{value}`. Valid values: {valid}") from err + return ScoreCategory(SCORE_CATEGORY_ARG_TO_COLUMN[arg]) + + # Maps user-facing run-status labels to underlying ``JobStatus`` values. Transient states # (Starting/Canceling) are excluded because they're sub-second and noisy as filters. # ``Pending`` collapses PENDING+CLAIMED; ``Canceled`` collapses CANCEL_REQUESTED+CANCELED. @@ -483,6 +541,16 @@ def resolve_profiling_run(job_execution_id: str) -> ProfilingRun: return run +def resolve_scorecard(scorecard_id: str) -> ScoreDefinition: + """Resolve a scorecard ID, collapsing missing-or-inaccessible into one error path.""" + parse_uuid(scorecard_id, "scorecard_id") + perms = get_project_permissions() + sd = ScoreDefinition.get(scorecard_id) + if sd is None or not perms.has_access(sd.project_code): + raise MCPResourceNotAccessible("Scorecard", scorecard_id) + return sd + + def resolve_test_definition(test_definition_id: str) -> TestDefinition: """Resolve a test definition ID to the live ORM model, collapsing missing-or-inaccessible. diff --git a/testgen/mcp/tools/quality_scores.py b/testgen/mcp/tools/quality_scores.py index c91c90f3..1545d7f4 100644 --- a/testgen/mcp/tools/quality_scores.py +++ b/testgen/mcp/tools/quality_scores.py @@ -1,27 +1,43 @@ +from collections import defaultdict + +from testgen.commands.run_refresh_score_cards_results import save_and_refresh_score_definition from testgen.common.models import with_database_session from testgen.common.models.scores import ( ScoreCategory, ScoreDefinition, + ScoreDefinitionBreakdownItem, ScoreDefinitionCriteria, + ScoreDefinitionFilter, ) from testgen.common.models.table_group import TableGroup from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError from testgen.mcp.permissions import get_project_permissions, mcp_permission from testgen.mcp.tools.common import ( + SCORE_CHAIN_LEAF_TO_COLUMN, SCORE_FILTER_FIELD_TO_COLUMN, SCORE_GROUP_BY_TO_COLUMN, DocGroup, + ScoreChainLeafField, + ScoreFilterField, ScoreGroupBy, ScoreType, - parse_score_filter_field, + format_page_footer, + format_page_info, + parse_category, parse_score_group_by, parse_score_type, + resolve_scorecard, resolve_table_group, + validate_limit, + validate_page, ) from testgen.mcp.tools.markdown import MdDoc from testgen.utils import friendly_score, friendly_score_impact -_DOC_GROUP = DocGroup.DISCOVER +_DOC_GROUP = DocGroup.SCORING + +_DEFAULT_LIMIT = 20 +_MAX_LIMIT = 100 _VALUE_MAX_LEN = 256 _VALUE_FORBIDDEN_CHARS = frozenset("'\";\\\x00") @@ -31,9 +47,21 @@ # (≤ ~15 dimensions/domains), but pathological metadata could blow this up. _ROW_CAP = 100 -_COMBINED_LABEL = "Combined Score" +_TOTAL_LABEL = "Total Score" _CDE_LABEL = "CDE Score" +_COLUMN_TO_LABEL: dict[str, str] = { + column: group_by.value for group_by, column in SCORE_GROUP_BY_TO_COLUMN.items() +} +# Chain-only fields (mode 2): not exposed as standalone filter fields but valid +# as the leaves of a `table_groups_name → table_name → column_name` chain. +_COLUMN_TO_LABEL["table_name"] = "Table" +_COLUMN_TO_LABEL["column_name"] = "Column" + + +_CHAIN_ROOT_FIELD = ScoreFilterField.TABLE_GROUP.value # "Table Group" +_CHAIN_LEAF_FIELDS = tuple(f.value for f in ScoreChainLeafField) # ("Table", "Column") + @with_database_session @mcp_permission("view") @@ -49,29 +77,44 @@ def get_quality_scores( ) -> str: """Quality-score rollup with optional grouping and filtering. + Returns overall Total, CDE, Profiling, and Testing scores by default, + plus an optional breakdown table when ``group_by`` is set. Scope is + project-wide unless ``project_code`` or ``table_group_id`` narrows it. + + **Filters.** Each filter is + ``{"field": "...", "value": "...", "others"?: [...]}``. Same-field values + OR together; different fields AND together. Valid flat fields: + ``"Table Group"``, ``"Data Location"``, ``"Data Source"``, + ``"Source System"``, ``"Source Process"``, ``"Business Domain"``, + ``"Stakeholder Group"``, ``"Transform Level"``, ``"Semantic Data Type"``, + ``"Data Product"``. To target specific tables or columns, chain a + ``"Table Group"`` filter via ``others`` into ``"Table"`` (optionally + then ``"Column"``); sibling chains OR. ``"Impact Dimension"`` and + ``"Quality Dimension"`` are valid as ``group_by`` only, not as filter + fields. Filter values must not contain quotes, semicolons, or + backslashes. ``table_group_id`` cannot be combined with chained + filters — put ``"Table Group"`` in the chain root instead. + Args: - project_code: Scope to a single project. Omit to roll across every - project the caller can view. - table_group_id: Scope to a single table group, e.g. from - ``get_data_inventory``. Mutually exclusive with ``project_code``. - group_by: One of 'Quality Dimension', 'Impact Dimension', - 'Semantic Data Type', 'Table Group', 'Data Location', - 'Data Source', 'Source System', 'Source Process', - 'Business Domain', 'Stakeholder Group', 'Transform Level', - 'Data Product'. Omit for the unfiltered overall score. - score_type: Narrows which score(s) are reported. Omit (default) to - show both Combined and CDE; pass 'Combined' to show only the - Combined Score, or 'CDE' to show only the CDE Score. - filters: List of {"field": str, "value": str} pairs. Same-field - filters OR together; different fields AND together. Valid fields - are the same as ``group_by`` except 'Quality Dimension' and - 'Impact Dimension', which are valid as ``group_by`` only. Filter - values must not contain quotes, semicolons, or backslashes. - include_issue_ct: When True, include the count of contributing issues + project_code: Scope to a project. Mutually exclusive with + ``table_group_id``. Omit both to roll across every visible + project. + table_group_id: Scope to a table group, e.g. from + ``get_data_inventory``. + group_by: Break overall scores out by one of: ``"Impact Dimension"``, + ``"Quality Dimension"``, ``"Semantic Data Type"``, + ``"Table Group"``, ``"Data Location"``, ``"Data Source"``, + ``"Source System"``, ``"Source Process"``, ``"Business Domain"``, + ``"Stakeholder Group"``, ``"Transform Level"``, + ``"Data Product"``. + score_type: Narrow returned scores. Omit to show all four (Total, + CDE, Profiling, Testing); pass ``"Total"`` for Total + Profiling + + Testing, or ``"CDE"`` for CDE alone. + filters: List of filter entries. See **Filters** above for shape. + include_issue_ct: Include the count of contributing issues (hygiene + test failures). - include_impact: When True, include the per-category impact on the - overall score (the percentage contribution to total quality - loss). Only affects grouped output. + include_impact: Include the per-category percentage impact on the + overall score. Only affects grouped output. """ perms = get_project_permissions() @@ -87,7 +130,13 @@ def get_quality_scores( parse_score_group_by(group_by) if group_by is not None else None ) - user_filters = _validate_filters(filters) + user_filters, group_by_field = _validate_filters(filters, allow_empty=True) + + if table_group_id is not None and not group_by_field: + raise MCPUserError( + "`table_group_id` cannot be combined with chained filters — " + "put `Table Group` in the chain root instead." + ) if table_group_id is not None: table_group = resolve_table_group(table_group_id) @@ -124,6 +173,7 @@ def get_quality_scores( group_by=parsed_group_by, score_type=parsed_score_type, user_filters=user_filters, + group_by_field=group_by_field, include_issue_ct=include_issue_ct, include_impact=include_impact, heading=code if cross_project else None, @@ -132,49 +182,6 @@ def get_quality_scores( return doc.render() -def _validate_filters(filters: list[dict] | None) -> list[dict]: - """Validate filter dicts and translate ``field`` from user labels to internal DB columns.""" - if not filters: - return [] - errors: list[str] = [] - cleaned: list[dict] = [] - for i, entry in enumerate(filters): - if not isinstance(entry, dict): - errors.append(f"entry {i}: must be a dict with `field` and `value`") - continue - field = entry.get("field") - value = entry.get("value") - if not field: - errors.append(f"entry {i}: missing `field`") - continue - if value is None or value == "": - errors.append(f"entry {i} ({field!r}): empty value") - continue - try: - parsed_field = parse_score_filter_field(field) - except MCPUserError as err: - errors.append(f"entry {i}: {err}") - continue - if not isinstance(value, str): - errors.append(f"entry {i} ({field!r}): value must be a string") - continue - if len(value) > _VALUE_MAX_LEN: - errors.append( - f"entry {i} ({field!r}): value too long ({len(value)} > {_VALUE_MAX_LEN})" - ) - continue - bad_chars = sorted(set(value) & _VALUE_FORBIDDEN_CHARS) - if bad_chars: - errors.append( - f"entry {i} ({field!r}): value contains forbidden characters {bad_chars}" - ) - continue - cleaned.append({"field": SCORE_FILTER_FIELD_TO_COLUMN[parsed_field], "value": value}) - if errors: - raise MCPUserError("Invalid filters: " + "; ".join(errors)) - return cleaned - - def _build_definition( *, project_code: str, @@ -182,6 +189,7 @@ def _build_definition( group_by: ScoreGroupBy | None, score_type: ScoreType | None, user_filters: list[dict], + group_by_field: bool, ) -> ScoreDefinition: definition = ScoreDefinition() definition.project_code = project_code @@ -190,7 +198,7 @@ def _build_definition( # `as_score_card` derives `cde_only_categories = cde_score and not # total_score` — so flag combinations decide whether the category SQL # filters by `critical_data_element = true`. - definition.total_score = score_type is None or score_type is ScoreType.COMBINED + definition.total_score = score_type is None or score_type is ScoreType.TOTAL definition.cde_score = score_type is None or score_type is ScoreType.CDE definition.category = ( ScoreCategory(SCORE_GROUP_BY_TO_COLUMN[group_by]) if group_by is not None else None @@ -219,7 +227,7 @@ def _build_definition( ) definition.criteria = ScoreDefinitionCriteria.from_filters( - filters, group_by_field=True, + filters, group_by_field=group_by_field, ) return definition @@ -232,6 +240,7 @@ def _render_one_scope( group_by: ScoreGroupBy | None, score_type: ScoreType | None, user_filters: list[dict], + group_by_field: bool, include_issue_ct: bool, include_impact: bool, heading: str | None, @@ -245,16 +254,20 @@ def _render_one_scope( group_by=group_by, score_type=score_type, user_filters=user_filters, + group_by_field=group_by_field, ) - show_combined = score_type is None or score_type is ScoreType.COMBINED + show_total = score_type is None or score_type is ScoreType.TOTAL show_cde = score_type is None or score_type is ScoreType.CDE card = definition.as_score_card() - if show_combined: - doc.field(_COMBINED_LABEL, friendly_score(card.get("score"))) + if show_total: + doc.field(_TOTAL_LABEL, friendly_score(card.get("score"))) if show_cde: doc.field(_CDE_LABEL, friendly_score(card.get("cde_score"))) + if show_total: + doc.field("Profiling Score", friendly_score(card.get("profiling_score"))) + doc.field("Testing Score", friendly_score(card.get("testing_score"))) if include_issue_ct and group_by is None: doc.field("Issue Count", definition.get_overall_issue_ct()) @@ -266,47 +279,50 @@ def _render_one_scope( # Per-category data — score, impact, issue_ct — comes from # get_score_card_breakdown. One call per enabled score type, since each - # filters different rows (Combined includes all data points; CDE filters + # filters different rows (Total includes all data points; CDE filters # to critical_data_element=true). - combined_rows: dict[str, dict] = {} + total_rows: dict[str, dict] = {} cde_rows: dict[str, dict] = {} - if show_combined: + if show_total: for r in definition.get_score_card_breakdown("score", group_by_column): label = r.get(group_by_column) if label is not None: - combined_rows[label] = r + total_rows[label] = r if show_cde: for r in definition.get_score_card_breakdown("cde_score", group_by_column): label = r.get(group_by_column) if label is not None: cde_rows[label] = r - all_labels = set(combined_rows) | set(cde_rows) + all_labels = set(total_rows) | set(cde_rows) if not all_labels: - doc.text("_No category data._") + if user_filters: + doc.text("_Filter matched no data._") + else: + doc.text("_No category data._") return - # Worst score first. Sort by primary column (Combined if shown, else CDE). + # Worst score first. Sort by primary column (Total if shown, else CDE). def _sort_key(label: str) -> float: - primary = combined_rows if show_combined else cde_rows + primary = total_rows if show_total else cde_rows score = (primary.get(label) or {}).get("score") return score if score is not None else 1.0 sorted_labels = sorted(all_labels, key=_sort_key) - total_rows = len(sorted_labels) + row_count = len(sorted_labels) capped = sorted_labels[:_ROW_CAP] - both_shown = show_combined and show_cde - combined_issue_header = "Issue Count (Combined)" if both_shown else "Issue Count" + both_shown = show_total and show_cde + total_issue_header = "Issue Count (Total)" if both_shown else "Issue Count" cde_issue_header = "Issue Count (CDE)" if both_shown else "Issue Count" headers: list[str] = [group_by.value] - if show_combined: - headers.append(_COMBINED_LABEL) + if show_total: + headers.append(_TOTAL_LABEL) if include_impact: - headers.append("Impact on Combined Score") + headers.append("Impact on Total Score") if include_issue_ct: - headers.append(combined_issue_header) + headers.append(total_issue_header) if show_cde: headers.append(_CDE_LABEL) if include_impact: @@ -317,9 +333,9 @@ def _sort_key(label: str) -> float: md_rows: list[list[object]] = [] for label in capped: cells: list[object] = [label] - c_row = combined_rows.get(label) or {} + c_row = total_rows.get(label) or {} d_row = cde_rows.get(label) or {} - if show_combined: + if show_total: cells.append(friendly_score(c_row.get("score"))) if include_impact: cells.append(_format_impact(c_row.get("impact"))) @@ -334,14 +350,682 @@ def _sort_key(label: str) -> float: md_rows.append(cells) doc.table(headers, md_rows) - if total_rows > _ROW_CAP: - doc.text(f"_Showing top {_ROW_CAP} of {total_rows} rows by lowest score._") + if row_count > _ROW_CAP: + doc.text(f"_Showing top {_ROW_CAP} of {row_count} rows by lowest score._") + + +@with_database_session +@mcp_permission("view") +def list_scorecards( + project_code: str, + page: int = 1, + limit: int = _DEFAULT_LIMIT, +) -> str: + """List the scorecards defined in a project. + + Args: + project_code: Project to list scorecards for. + page: Page number, starting at 1. + limit: Page size (max 100). + """ + validate_page(page) + validate_limit(limit, _MAX_LIMIT) + + perms = get_project_permissions() + perms.verify_access( + project_code, + not_found=MCPResourceNotAccessible("Project", project_code), + ) + + definitions, total = ScoreDefinition.list_for_project( + project_code, page=page, limit=limit, + ) + + doc = MdDoc() + doc.heading(1, f"Scorecards in Project `{project_code}`") + + page_info = format_page_info(total, page, limit) + if page_info: + doc.text(page_info) + + if not definitions: + if page > 1: + doc.text(f"_No scorecards on page {page} (total: {total})._") + else: + doc.text("_No scorecards configured._") + return doc.render() + + for definition in definitions: + doc.heading(2, f"{definition.name} (id: `{definition.id}`)") + card = definition.as_cached_score_card() + if definition.total_score: + doc.field(_TOTAL_LABEL, friendly_score(card.get("score"))) + if definition.cde_score: + doc.field(_CDE_LABEL, friendly_score(card.get("cde_score"))) + if definition.total_score: + doc.field("Profiling Score", friendly_score(card.get("profiling_score"))) + doc.field("Testing Score", friendly_score(card.get("testing_score"))) + if definition.category is not None: + doc.field("Category", _column_label(definition.category.value)) + doc.field("Filters", _format_criteria_summary(definition.criteria)) + + footer = format_page_footer(total, page, limit) + if footer: + doc.text(footer) + + return doc.render() + + +@with_database_session +@mcp_permission("view") +def get_scorecard(scorecard_id: str) -> str: + """Get a scorecard with its current scores and per-category breakdown. + + Args: + scorecard_id: UUID returned by ``list_scorecards`` or ``get_data_inventory``. + """ + definition = resolve_scorecard(scorecard_id) + card = definition.as_cached_score_card() + + doc = MdDoc() + doc.heading(1, f"Scorecard: {definition.name}") + + doc.field("ID", definition.id, code=True) + doc.field("Project", definition.project_code, code=True) + if definition.total_score: + doc.field(_TOTAL_LABEL, friendly_score(card.get("score"))) + if definition.cde_score: + doc.field(_CDE_LABEL, friendly_score(card.get("cde_score"))) + if definition.total_score: + doc.field("Profiling Score", friendly_score(card.get("profiling_score"))) + doc.field("Testing Score", friendly_score(card.get("testing_score"))) + if definition.category is not None: + doc.field("Category", _column_label(definition.category.value)) + doc.field("Filters", _format_criteria_summary(definition.criteria)) + + if definition.category is not None: + _render_breakdown(doc, definition) + + return doc.render() + + +def _render_breakdown(doc: MdDoc, definition: ScoreDefinition) -> None: + """Render the per-category breakdown table for an enabled score_type pair. + + Total and CDE rows are merged by label so the same category value shows + on one line with both score_types. Sorted by primary-score-type impact + desc; capped at ``_ROW_CAP`` rows with a truncation footer when exceeded. + """ + category_column = definition.category.value + category_label = _column_label(category_column) + doc.heading(2, f"Breakdown by {category_label}") + + show_total = definition.total_score + show_cde = definition.cde_score + + total_rows: dict[str, dict] = {} + cde_rows: dict[str, dict] = {} + if show_total: + for item in ScoreDefinitionBreakdownItem.filter( + definition_id=definition.id, + category=category_column, + score_type="score", + ): + row = item.to_dict() + label = _row_label(row, category_column) + if label is not None: + total_rows[label] = row + if show_cde: + for item in ScoreDefinitionBreakdownItem.filter( + definition_id=definition.id, + category=category_column, + score_type="cde_score", + ): + row = item.to_dict() + label = _row_label(row, category_column) + if label is not None: + cde_rows[label] = row + + all_labels = set(total_rows) | set(cde_rows) + if not all_labels: + doc.text("_No breakdown data._") + return + + primary = total_rows if show_total else cde_rows + + def _sort_key(label: str) -> float: + impact = (primary.get(label) or {}).get("impact") + return impact if impact is not None else 0.0 + + # Highest impact first — same ordering as the cached rows from the model. + sorted_labels = sorted(all_labels, key=_sort_key, reverse=True) + row_count = len(sorted_labels) + capped = sorted_labels[:_ROW_CAP] + + both_shown = show_total and show_cde + total_issue_header = "Issue Count (Total)" if both_shown else "Issue Count" + cde_issue_header = "Issue Count (CDE)" if both_shown else "Issue Count" + + headers: list[str] = [category_label] + if show_total: + headers.extend([_TOTAL_LABEL, "Impact on Total Score", total_issue_header]) + if show_cde: + headers.extend([_CDE_LABEL, "Impact on CDE Score", cde_issue_header]) + + md_rows: list[list[object]] = [] + for label in capped: + cells: list[object] = [label] + c_row = total_rows.get(label) or {} + d_row = cde_rows.get(label) or {} + if show_total: + cells.append(friendly_score(c_row.get("score"))) + cells.append(_format_impact(c_row.get("impact"))) + cells.append(c_row.get("issue_ct") if c_row else None) + if show_cde: + cells.append(friendly_score(d_row.get("score"))) + cells.append(_format_impact(d_row.get("impact"))) + cells.append(d_row.get("issue_ct") if d_row else None) + md_rows.append(cells) + doc.table(headers, md_rows) + + if row_count > _ROW_CAP: + doc.text(f"_Showing top {_ROW_CAP} of {row_count} rows by highest impact._") + + +@with_database_session +@mcp_permission("edit") +def create_scorecard( + project_code: str, + name: str, + filters: list[dict], + *, + category: str | None = None, + show_total_score: bool = True, + show_cde_score: bool = False, +) -> str: + """Create a scorecard in a project. + + **Filters.** At least one filter is required. Each entry is + ``{"field": "...", "value": "...", "others"?: [...]}``. Same-field values + OR together; different fields AND together. Valid flat fields: + ``"Table Group"``, ``"Data Location"``, ``"Data Source"``, + ``"Source System"``, ``"Source Process"``, ``"Business Domain"``, + ``"Stakeholder Group"``, ``"Transform Level"``, ``"Semantic Data Type"``, + ``"Data Product"``. To target specific tables or columns, chain a + ``"Table Group"`` filter via ``others`` into ``"Table"`` (optionally + then ``"Column"``); sibling chains OR. + + Args: + project_code: Project that will own the scorecard. + name: Scorecard name. Must be non-empty. + filters: List of filter entries. See **Filters** above for shape. + category: Category for per-bucket breakdown. One of + ``"Quality Dimension"``, ``"Impact Dimension"``, + ``"Data Source"``, ``"Business Domain"``, ``"Stakeholder Group"``, + ``"Table Group"``, ``"Transform Level"``, ``"Data Location"``, + ``"Source System"``, ``"Source Process"``, ``"Data Product"``. + show_total_score: Whether the scorecard exposes the Total Score. + show_cde_score: Whether the scorecard exposes the CDE Score. + """ + perms = get_project_permissions() + perms.verify_access( + project_code, + not_found=MCPResourceNotAccessible("Project", project_code), + ) + + if not name.strip(): + raise MCPUserError("`name` must be non-empty.") + + parsed_filters, group_by_field = _validate_filters(filters) + category_value = parse_category(category) if category is not None else None + + definition = ScoreDefinition() + definition.project_code = project_code + definition.name = name + definition.total_score = show_total_score + definition.cde_score = show_cde_score + definition.category = category_value + definition.criteria = ScoreDefinitionCriteria.from_filters( + parsed_filters, + group_by_field=group_by_field, + ) + + save_and_refresh_score_definition(definition, is_new=True) + + doc = MdDoc() + doc.heading(1, f"Scorecard `{definition.name}` created") + doc.field("ID", definition.id, code=True) + doc.field("Project", definition.project_code, code=True) + doc.field(_TOTAL_LABEL, "Yes" if show_total_score else "No") + doc.field(_CDE_LABEL, "Yes" if show_cde_score else "No") + if category_value is not None: + doc.field("Category", _column_label(category_value.value)) + doc.field("Filters", _format_criteria_summary(definition.criteria)) + return doc.render() + + +@with_database_session +@mcp_permission("edit") +def update_scorecard( + scorecard_id: str, + *, + name: str | None = None, + show_total_score: bool | None = None, + show_cde_score: bool | None = None, + category: str | None = None, + filters: list[dict] | None = None, +) -> str: + """Update fields on an existing scorecard. Pass only the fields to change. + + **Filters.** When supplied, ``filters`` replaces the scorecard's filters + wholesale and at least one entry is required. Each entry is + ``{"field": "...", "value": "...", "others"?: [...]}``. Same-field values + OR together; different fields AND together. Valid flat fields: + ``"Table Group"``, ``"Data Location"``, ``"Data Source"``, + ``"Source System"``, ``"Source Process"``, ``"Business Domain"``, + ``"Stakeholder Group"``, ``"Transform Level"``, ``"Semantic Data Type"``, + ``"Data Product"``. To target specific tables or columns, chain a + ``"Table Group"`` filter via ``others`` into ``"Table"`` (optionally + then ``"Column"``); sibling chains OR. + + Args: + scorecard_id: UUID returned by ``list_scorecards`` or + ``get_data_inventory``. + name: New scorecard name. Must be non-empty when supplied. + show_total_score: Whether the scorecard exposes the Total Score. + show_cde_score: Whether the scorecard exposes the CDE Score. + category: Category for per-bucket breakdown. One of + ``"Quality Dimension"``, ``"Impact Dimension"``, + ``"Data Source"``, ``"Business Domain"``, ``"Stakeholder Group"``, + ``"Table Group"``, ``"Transform Level"``, ``"Data Location"``, + ``"Source System"``, ``"Source Process"``, ``"Data Product"``. + Pass ``""`` to clear an existing category. + filters: List of filter entries. See **Filters** above for shape. + """ + definition = resolve_scorecard(scorecard_id) + + new_category: ScoreCategory | None = None + clear_category = category == "" + if category is not None and not clear_category: + new_category = parse_category(category) + + parsed_filters: list[dict] | None = None + group_by_field: bool | None = None + if filters is not None: + parsed_filters, group_by_field = _validate_filters(filters) + + pending: dict = {} + if name is not None: + if not name.strip(): + raise MCPUserError("`name` must be non-empty.") + pending["name"] = name + if show_total_score is not None: + pending["total_score"] = show_total_score + if show_cde_score is not None: + pending["cde_score"] = show_cde_score + if new_category is not None: + pending["category"] = new_category + elif clear_category: + pending["category"] = None + if parsed_filters is not None: + pending["criteria"] = ScoreDefinitionCriteria.from_filters( + parsed_filters, + group_by_field=group_by_field, + ) + + if not pending: + raise MCPUserError("No fields supplied to update.") + + before = _snapshot_for_diff(definition, pending) + for attr, value in pending.items(): + setattr(definition, attr, value) + after = _snapshot_for_diff(definition, pending) + + save_and_refresh_score_definition(definition, is_new=False) + + doc = MdDoc() + doc.heading(1, f"Scorecard `{definition.name}` updated") + doc.field("ID", definition.id, code=True) + doc.field("Project", definition.project_code, code=True) + rows = [ + [_DIFF_LABELS[attr], before[attr], after[attr]] + for attr in pending + ] + doc.table(["Field", "Before", "After"], rows, code=[0]) + return doc.render() + + +_DIFF_LABELS: dict[str, str] = { + "name": "Name", + "total_score": _TOTAL_LABEL, + "cde_score": _CDE_LABEL, + "category": "Category", + "criteria": "Filters", +} + + +def _snapshot_for_diff(definition: ScoreDefinition, attrs: dict) -> dict[str, str | None]: + """Render display-form values for each attr being changed.""" + snapshot: dict[str, str | None] = {} + for attr in attrs: + value = getattr(definition, attr, None) + if attr == "category": + snapshot[attr] = _column_label(value.value) if value is not None else None + elif attr == "criteria": + snapshot[attr] = _format_criteria_summary(value) + elif isinstance(value, bool): + snapshot[attr] = "Yes" if value else "No" + else: + snapshot[attr] = value if value is not None else None + return snapshot + + +@with_database_session +@mcp_permission("edit") +def delete_scorecard(scorecard_id: str) -> str: + """Delete a scorecard. + + Args: + scorecard_id: UUID returned by ``list_scorecards`` or ``get_data_inventory``. + """ + definition = resolve_scorecard(scorecard_id) + name = definition.name + project_code = definition.project_code + deleted_id = definition.id + + definition.delete() + + doc = MdDoc() + doc.heading(1, f"Scorecard `{name}` deleted") + doc.field("ID", deleted_id, code=True) + doc.field("Project", project_code, code=True) + return doc.render() + +def _filter_value_errors(value: object, field: str) -> list[str]: + """Return error strings for an unsafe filter value (empty list if safe). + + Catches non-string types, over-length values, and forbidden characters + that would enable SQL injection via ``ScoreDefinitionCriteria.get_as_sql``. + Does not check for empty/missing values — callers handle that separately. + """ + if not isinstance(value, str): + return [f"({field!r}): value must be a string"] + errors: list[str] = [] + if len(value) > _VALUE_MAX_LEN: + errors.append(f"({field!r}): value too long ({len(value)} > {_VALUE_MAX_LEN})") + bad_chars = sorted(set(value) & _VALUE_FORBIDDEN_CHARS) + if bad_chars: + errors.append(f"({field!r}): value contains forbidden characters {bad_chars}") + return errors + + +def _validate_filters( + raw_filters: list[dict] | None, *, allow_empty: bool = False, +) -> tuple[list[dict], bool]: + """Validate user-supplied filter shape and translate to column-form storage. + + Returns ``(parsed_filters, group_by_field)``. Input ``field`` values are + display-form (e.g. ``"Table Group"``, ``"Data Source"``, ``"Table"``, + ``"Column"``); the returned dicts use the underlying DB column names + (e.g. ``"table_groups_name"``, ``"table_name"``). + + Two storage modes (selectable per call, not mutually exclusive across + callers): + + * Mode 1 (flat, ``group_by_field=True``): every filter is a single + ``(field, value)`` pair using one of the values from ``ScoreFilterField``. + * Mode 2 (chained, ``group_by_field=False``): each chained filter roots at + ``"Table Group"`` and chains only into ``"Table"`` then ``"Column"``. A + flat ``"Table Group"`` filter is also valid here. + + Errors are collected across every offending entry and reported in one + ``MCPUserError`` so callers see every problem at once rather than chasing + one fix at a time. + + When ``allow_empty=True``, ``None`` / ``[]`` short-circuits to + ``([], True)``. With the default ``allow_empty=False``, empty input raises. + """ + if not raw_filters: + if allow_empty: + return [], True + raise MCPUserError("At least one filter is required.") + + errors: list[str] = [] + for index, filter_ in enumerate(raw_filters): + if not filter_.get("field") or not filter_.get("value"): + errors.append( + f"filters[{index}] must have non-empty `field` and `value`." + ) + continue + errors.extend( + f"filters[{index}] {err}" + for err in _filter_value_errors(filter_["value"], filter_["field"]) + ) + + valid_mode_1_fields = {f.value for f in ScoreFilterField} + has_chain = any( + isinstance(filter_, dict) and filter_.get("others") + for filter_ in raw_filters + ) + + if not has_chain: + parsed: list[dict] = [] + for index, filter_ in enumerate(raw_filters): + if not filter_.get("field") or not filter_.get("value"): + continue + field = filter_["field"] + if field not in valid_mode_1_fields: + valid = ", ".join(sorted(valid_mode_1_fields)) + errors.append( + f"filters[{index}]: `{field}` is not a valid scorecard filter " + f"field. To target specific tables or columns, chain a " + f"`{_CHAIN_ROOT_FIELD}` filter with `others`: " + f'[{{"field": "Table", "value": "..."}}]. ' + f"Valid flat fields: {valid}." + ) + continue + parsed.append({ + "field": SCORE_FILTER_FIELD_TO_COLUMN[ScoreFilterField(field)], + "value": filter_["value"], + }) + if errors: + raise MCPUserError("Invalid filters: " + "; ".join(errors)) + return parsed, True + + parsed_chained: list[dict] = [] + for index, filter_ in enumerate(raw_filters): + if not filter_.get("field") or not filter_.get("value"): + continue + field = filter_["field"] + others = filter_.get("others") or [] + if others and field != _CHAIN_ROOT_FIELD: + errors.append( + f"filters[{index}]: chained filters must root at " + f"`{_CHAIN_ROOT_FIELD}`, got `{field}`." + ) + continue + if not others and field != _CHAIN_ROOT_FIELD: + errors.append( + f"filters[{index}]: when any filter chains tables/columns, " + f"all filters must root at `{_CHAIN_ROOT_FIELD}`. Got `{field}`." + ) + continue + + translated_others: list[dict] = [] + chain_errors = False + for chain_index, chain in enumerate(others): + if not chain.get("field") or not chain.get("value"): + errors.append( + f"filters[{index}].others[{chain_index}] must have " + f"non-empty `field` and `value`." + ) + chain_errors = True + continue + chain_field = chain["field"] + if chain_field not in _CHAIN_LEAF_FIELDS: + errors.append( + f"filters[{index}].others[{chain_index}]: `{chain_field}` " + f"is not a valid chain field. Chains may only descend into " + f"{' or '.join(f'`{f}`' for f in _CHAIN_LEAF_FIELDS)}." + ) + chain_errors = True + continue + value_errors = _filter_value_errors(chain["value"], chain_field) + if value_errors: + errors.extend( + f"filters[{index}].others[{chain_index}] {err}" + for err in value_errors + ) + chain_errors = True + continue + translated_others.append({ + "field": SCORE_CHAIN_LEAF_TO_COLUMN[ScoreChainLeafField(chain_field)], + "value": chain["value"], + }) + + chain_field_values = [c.get("field") for c in others] + if chain_field_values == [ScoreChainLeafField.COLUMN.value]: + errors.append( + f"filters[{index}]: a `Column` chain requires a `Table` step before it." + ) + continue + if ScoreChainLeafField.COLUMN.value in chain_field_values[:-1]: + errors.append( + f"filters[{index}]: `Column` must be the final chain step." + ) + continue + + if chain_errors: + continue + + parsed_chained.append({ + "field": SCORE_FILTER_FIELD_TO_COLUMN[ScoreFilterField.TABLE_GROUP], + "value": filter_["value"], + "others": translated_others, + }) + + if errors: + raise MCPUserError("Invalid filters: " + "; ".join(errors)) + return parsed_chained, False + + +def _row_label(row: dict, category_column: str) -> str | None: + """Compose the display label for a breakdown row. + + For ``column_name`` breakdowns, prefix with the table name so columns with + the same name from different tables don't collapse into one bucket. NULL + category values (e.g. table-scope tests with no column_name) return + ``None`` so the row is skipped — matches ``get_quality_scores``. + """ + if category_column == "column_name": + table = row.get("table_name") + column = row.get("column_name") + if column is None: + return None + return f"{table}.{column}" if table else column + return row.get(category_column) def _format_impact(value: float | None) -> str | None: - # Pass None through so MdDoc renders an em-dash for missing data - # (friendly_score_impact returns the literal "-" for None/0, which - # mismatches the score column's em-dash treatment). + # Pass None through so MdDoc renders an em-dash for missing data — + # friendly_score_impact returns the literal "-" for None/0, which + # mismatches the score column's em-dash treatment. if value is None: return None return friendly_score_impact(value) + + +def _format_criteria_summary(criteria: ScoreDefinitionCriteria | None) -> str: + """Human-readable summary of a scorecard's criteria. + + Two render modes, dispatched by filter shape: + + * Mode 1 (flat filters only): same-field values collapse to ``Label in (a, b)`` + when ``group_by_field=True``; different fields are AND-joined alphabetically + by display label for stable output. + * Mode 2 (any filter has a ``next_filter`` chain): chains are grouped by + ``(root_field, root_value)``; siblings sharing the same chain shape collapse + their leaves into ``in (...)``; root groups are OR-joined. + """ + if criteria is None or not criteria.has_filters(): + return "(no filters)" + + if any(root.next_filter is not None for root in criteria.filters): + return _format_mode_2_summary(criteria) + return _format_mode_1_summary(criteria) + + +def _format_mode_1_summary(criteria: ScoreDefinitionCriteria) -> str: + simple_by_field: dict[str, list[str]] = defaultdict(list) + for root in criteria.filters: + simple_by_field[root.field].append(root.value) + + rendered: list[tuple[str, str]] = [] + for field, values in simple_by_field.items(): + label = _column_label(field) + if len(values) == 1: + rendered.append((label, f"{label} = {values[0]}")) + elif criteria.group_by_field: + rendered.append((label, f"{label} in ({', '.join(values)})")) + else: + joiner = f" {criteria.operand} " + rendered.append((label, joiner.join(f"{label} = {v}" for v in values))) + + rendered.sort(key=lambda p: p[0]) + return " AND ".join(part for _, part in rendered) + + +def _format_mode_2_summary(criteria: ScoreDefinitionCriteria) -> str: + """Render mode-2 (chained) filters with OR semantics and leaf collapse.""" + grouped: dict[tuple[str, str], list[ScoreDefinitionFilter]] = defaultdict(list) + for root in criteria.filters: + grouped[(root.field, root.value)].append(root) + + branches: list[str] = [] + for (root_field, root_value), siblings in grouped.items(): + root_part = f"{_column_label(root_field)} = {root_value}" + chain_paths: list[list[tuple[str, str]]] = [] + for root in siblings: + path: list[tuple[str, str]] = [] + current = root.next_filter + while current is not None: + path.append((current.field, current.value)) + current = current.next_filter + chain_paths.append(path) + + non_empty_paths = [p for p in chain_paths if p] + has_empty = any(not p for p in chain_paths) + + if not non_empty_paths: + branches.append(root_part) + continue + + same_shape = len({tuple(field for field, _ in p) for p in non_empty_paths}) == 1 + if same_shape and not has_empty: + leaf_fields = [field for field, _ in non_empty_paths[0]] + leaf_parts: list[str] = [] + for i, field in enumerate(leaf_fields): + values = [p[i][1] for p in non_empty_paths] + label = _column_label(field) + if len(set(values)) == 1: + leaf_parts.append(f"{label} = {values[0]}") + else: + leaf_parts.append(f"{label} in ({', '.join(values)})") + branches.append(f"{root_part} AND {' AND '.join(leaf_parts)}") + else: + sub_branches: list[str] = [] + for path in chain_paths: + if not path: + sub_branches.append(root_part) + else: + leaves = [f"{_column_label(field)} = {value}" for field, value in path] + sub_branches.append(f"({root_part} AND {' AND '.join(leaves)})") + branches.append(" OR ".join(sub_branches)) + + if len(branches) == 1: + return branches[0] + return " OR ".join(f"({b})" if " AND " in b else b for b in branches) + + +def _column_label(column: str) -> str: + return _COLUMN_TO_LABEL.get(column, column) diff --git a/testgen/ui/scripts/patch_streamlit.py b/testgen/ui/scripts/patch_streamlit.py index 16de43cc..88476737 100644 --- a/testgen/ui/scripts/patch_streamlit.py +++ b/testgen/ui/scripts/patch_streamlit.py @@ -1,4 +1,3 @@ -# ruff: noqa: TRY002 import pathlib import re diff --git a/testgen/ui/views/score_explorer.py b/testgen/ui/views/score_explorer.py index 5eca56ff..e9598ddd 100644 --- a/testgen/ui/views/score_explorer.py +++ b/testgen/ui/views/score_explorer.py @@ -1,19 +1,14 @@ import json import typing -from datetime import datetime from io import BytesIO from typing import ClassVar import pandas as pd import streamlit as st -from testgen.commands.run_refresh_score_cards_results import ( - run_recalculate_score_card, - run_refresh_score_cards_results, -) +from testgen.commands.run_refresh_score_cards_results import save_and_refresh_score_definition from testgen.common.mixpanel_service import MixpanelService from testgen.common.models import with_database_session -from testgen.common.models.profiling_run import ProfilingRun from testgen.common.models.scores import ( Categories, ScoreCategory, @@ -22,7 +17,6 @@ ScoreTypes, SelectedIssue, ) -from testgen.common.models.test_run import TestRun from testgen.common.pii_masking import get_pii_columns, mask_hygiene_detail, mask_profiling_pii from testgen.ui.components import widgets as testgen from testgen.ui.components.widgets.download_dialog import FILE_DATA_TYPE, download_dialog, zip_multi_file_data @@ -373,24 +367,11 @@ def save_score_definition(_) -> None: is_new = True score_definition = ScoreDefinition() - refresh_kwargs = {} if definition_id: is_new = False score_definition = ScoreDefinition.get(definition_id) project_code = score_definition.project_code - if is_new: - latest_run = max( - ProfilingRun.get_latest_run(project_code), - TestRun.get_latest_run(project_code), - key=lambda run: getattr(run, "run_time", datetime.min), - ) - - refresh_kwargs = { - "add_history_entry": True, - "refresh_date": latest_run.run_time if latest_run else None, - } - score_definition.project_code = project_code score_definition.name = name score_definition.total_score = total_score and total_score.lower() == "true" @@ -403,13 +384,9 @@ def save_score_definition(_) -> None: ], group_by_field=not filter_by_columns, ) - score_definition.save() - run_refresh_score_cards_results(definition_id=score_definition.id, **refresh_kwargs) + save_and_refresh_score_definition(score_definition, is_new=is_new) get_all_score_cards.clear() - if not is_new: - run_recalculate_score_card(project_code=project_code, definition_id=score_definition.id) - Router().set_query_params({ "name": None, "total_score": None, diff --git a/tests/unit/commands/test_score_cards.py b/tests/unit/commands/test_score_cards.py index a537eee3..584d185c 100644 --- a/tests/unit/commands/test_score_cards.py +++ b/tests/unit/commands/test_score_cards.py @@ -1,8 +1,14 @@ +from datetime import UTC, datetime +from unittest.mock import patch from uuid import uuid4 import pytest -from testgen.commands.run_refresh_score_cards_results import _score_card_to_results +from testgen.commands.run_refresh_score_cards_results import ( + _score_card_to_results, + save_and_refresh_score_definition, +) +from testgen.common.models.scores import ScoreDefinition, ScoreDefinitionCriteria pytestmark = pytest.mark.unit @@ -80,3 +86,128 @@ def test_none_score_values(): results = _score_card_to_results(card) for result in results: assert result.score is None + + +# --- save_and_refresh_score_definition --- + + +def _fake_definition(project_code: str = "demo") -> ScoreDefinition: + sd = ScoreDefinition() + sd.id = uuid4() + sd.project_code = project_code + sd.name = "Card" + sd.total_score = True + sd.cde_score = False + sd.category = None + sd.criteria = ScoreDefinitionCriteria.from_filters( + [{"field": "table_groups_name", "value": "tg1"}], + group_by_field=True, + ) + return sd + + +def test_save_and_refresh_score_definition_for_existing_card_calls_save_refresh_and_recalculate(): + """is_new=False path: save → refresh → recalculate, all in that order.""" + sd = _fake_definition() + call_order: list[str] = [] + + def record(name): + def _called(*_a, **_kw): + call_order.append(name) + return _called + + with ( + patch.object(ScoreDefinition, "save", autospec=True, side_effect=record("save")), + patch( + "testgen.commands.run_refresh_score_cards_results.run_refresh_score_cards_results", + side_effect=record("refresh"), + ), + patch( + "testgen.commands.run_refresh_score_cards_results.run_recalculate_score_card", + side_effect=record("recalculate"), + ), + ): + save_and_refresh_score_definition(sd, is_new=False) + + assert call_order == ["save", "refresh", "recalculate"] + + +def test_save_and_refresh_score_definition_for_existing_card_passes_refresh_kwargs_for_update(): + """Updates (is_new=False) do NOT pass add_history_entry / refresh_date.""" + sd = _fake_definition() + + with ( + patch.object(ScoreDefinition, "save", autospec=True), + patch( + "testgen.commands.run_refresh_score_cards_results.run_refresh_score_cards_results", + ) as mock_refresh, + patch( + "testgen.commands.run_refresh_score_cards_results.run_recalculate_score_card", + ), + ): + save_and_refresh_score_definition(sd, is_new=False) + + mock_refresh.assert_called_once_with(definition_id=sd.id) + + +def test_save_and_refresh_score_definition_for_new_card_skips_recalculate(): + """is_new=True path: save → refresh with history kwargs; no recalculate.""" + sd = _fake_definition() + + fake_latest = type("Run", (), {"run_time": datetime(2026, 5, 1, tzinfo=UTC)})() + + with ( + patch.object(ScoreDefinition, "save", autospec=True), + patch( + "testgen.commands.run_refresh_score_cards_results.ProfilingRun.get_latest_run", + return_value=fake_latest, + ), + patch( + "testgen.commands.run_refresh_score_cards_results.TestRun.get_latest_run", + return_value=None, + ), + patch( + "testgen.commands.run_refresh_score_cards_results.run_refresh_score_cards_results", + ) as mock_refresh, + patch( + "testgen.commands.run_refresh_score_cards_results.run_recalculate_score_card", + ) as mock_recalc, + ): + save_and_refresh_score_definition(sd, is_new=True) + + mock_refresh.assert_called_once_with( + definition_id=sd.id, + add_history_entry=True, + refresh_date=fake_latest.run_time, + ) + mock_recalc.assert_not_called() + + +def test_save_and_refresh_score_definition_for_new_card_handles_no_runs(): + """When there are no profiling/test runs for the project, refresh_date is None.""" + sd = _fake_definition() + + with ( + patch.object(ScoreDefinition, "save", autospec=True), + patch( + "testgen.commands.run_refresh_score_cards_results.ProfilingRun.get_latest_run", + return_value=None, + ), + patch( + "testgen.commands.run_refresh_score_cards_results.TestRun.get_latest_run", + return_value=None, + ), + patch( + "testgen.commands.run_refresh_score_cards_results.run_refresh_score_cards_results", + ) as mock_refresh, + patch( + "testgen.commands.run_refresh_score_cards_results.run_recalculate_score_card", + ), + ): + save_and_refresh_score_definition(sd, is_new=True) + + mock_refresh.assert_called_once_with( + definition_id=sd.id, + add_history_entry=True, + refresh_date=None, + ) diff --git a/tests/unit/common/models/test_score_definition.py b/tests/unit/common/models/test_score_definition.py index b0acb712..2be46506 100644 --- a/tests/unit/common/models/test_score_definition.py +++ b/tests/unit/common/models/test_score_definition.py @@ -207,6 +207,21 @@ def test_list_with_table_group_targets_empty_project(mock_session_fn): assert ScoreDefinition.list_with_table_group_targets("proj") == [] +@patch("testgen.common.models.scores.get_current_session") +def test_list_with_table_group_targets_dedupes_repeated_names(mock_session_fn): + """A mode-2 scorecard with N chains all rooted at the same table_groups_name + must surface that name only once — otherwise the inventory tool lists the + scorecard once per chain under the same table group.""" + def_id = uuid4() + mock_result = MagicMock() + mock_result.all.return_value = [_row(def_id, "redbox-tables", ["redbox"] * 4)] + mock_session_fn.return_value.execute.return_value = mock_result + + out = ScoreDefinition.list_with_table_group_targets("proj") + + assert out == [(def_id, "redbox-tables", ["redbox"])] + + # --- get_overall_issue_ct --- @@ -289,3 +304,131 @@ def test_get_overall_issue_ct_no_filters_returns_zero(): with patch("testgen.common.models.scores.get_current_session") as mock_session_fn: assert definition.get_overall_issue_ct() == 0 mock_session_fn.return_value.execute.assert_not_called() + + +# --- list_for_project --- + + +def _make_scorecard_orm(name: str, project_code: str = "demo") -> ScoreDefinition: + sd = ScoreDefinition() + sd.id = uuid4() + sd.project_code = project_code + sd.name = name + sd.total_score = True + sd.cde_score = False + return sd + + +@patch("testgen.common.models.scores.get_current_session") +def test_list_for_project_returns_items_and_total(mock_session_fn): + """Returns (rows, total) from scalars().unique() and the count scalar.""" + sd_a = _make_scorecard_orm("Apple") + sd_b = _make_scorecard_orm("Mango") + + session = mock_session_fn.return_value + session.scalar.return_value = 2 + scalars_result = MagicMock() + scalars_result.unique.return_value.all.return_value = [sd_a, sd_b] + session.scalars.return_value = scalars_result + + items, total = ScoreDefinition.list_for_project("demo", page=1, limit=20) + + assert items == [sd_a, sd_b] + assert total == 2 + + +@patch("testgen.common.models.scores.get_current_session") +def test_list_for_project_filters_by_project_code(mock_session_fn): + """The page query's compiled SQL must filter by project_code.""" + session = mock_session_fn.return_value + session.scalar.return_value = 0 + scalars_result = MagicMock() + scalars_result.unique.return_value.all.return_value = [] + session.scalars.return_value = scalars_result + + ScoreDefinition.list_for_project("my-proj") + + page_call = session.scalars.call_args + sql = str(page_call.args[0].compile(compile_kwargs={"literal_binds": True})) + assert "project_code" in sql + assert "'my-proj'" in sql + + +@patch("testgen.common.models.scores.get_current_session") +def test_list_for_project_orders_by_name(mock_session_fn): + """The page query must include ORDER BY name for stable pagination.""" + session = mock_session_fn.return_value + session.scalar.return_value = 0 + scalars_result = MagicMock() + scalars_result.unique.return_value.all.return_value = [] + session.scalars.return_value = scalars_result + + ScoreDefinition.list_for_project("demo") + + sql = str(session.scalars.call_args.args[0].compile(compile_kwargs={"literal_binds": True})) + assert "ORDER BY" in sql.upper() + assert "score_definitions.name" in sql.lower() + + +@patch("testgen.common.models.scores.get_current_session") +def test_list_for_project_applies_offset_and_limit(mock_session_fn): + """page=3, limit=10 → OFFSET 20 LIMIT 10.""" + session = mock_session_fn.return_value + session.scalar.return_value = 100 + scalars_result = MagicMock() + scalars_result.unique.return_value.all.return_value = [] + session.scalars.return_value = scalars_result + + ScoreDefinition.list_for_project("demo", page=3, limit=10) + + sql = str(session.scalars.call_args.args[0].compile(compile_kwargs={"literal_binds": True})) + assert "LIMIT 10" in sql + assert "OFFSET 20" in sql + + +@patch("testgen.common.models.scores.get_current_session") +def test_list_for_project_eager_loads_criteria(mock_session_fn): + """Criteria must be joinedload'd so the rendering loop doesn't fire N+1.""" + session = mock_session_fn.return_value + session.scalar.return_value = 0 + scalars_result = MagicMock() + scalars_result.unique.return_value.all.return_value = [] + session.scalars.return_value = scalars_result + + ScoreDefinition.list_for_project("demo") + + sql = str(session.scalars.call_args.args[0].compile(compile_kwargs={"literal_binds": True})) + # joinedload emits a LEFT OUTER JOIN against the criteria table. + assert "score_definition_criteria" in sql.lower() + + +@patch("testgen.common.models.scores.get_current_session") +def test_list_for_project_count_is_separate_query(mock_session_fn): + """A scalar count query runs alongside the paged scalars query.""" + session = mock_session_fn.return_value + session.scalar.return_value = 7 + scalars_result = MagicMock() + scalars_result.unique.return_value.all.return_value = [] + session.scalars.return_value = scalars_result + + _, total = ScoreDefinition.list_for_project("demo") + + assert total == 7 + assert session.scalar.call_count == 1 + count_sql = str(session.scalar.call_args.args[0].compile(compile_kwargs={"literal_binds": True})) + assert "count(" in count_sql.lower() + assert "'demo'" in count_sql + + +@patch("testgen.common.models.scores.get_current_session") +def test_list_for_project_count_null_returns_zero(mock_session_fn): + """When count() returns NULL on an empty table, normalize to 0.""" + session = mock_session_fn.return_value + session.scalar.return_value = None + scalars_result = MagicMock() + scalars_result.unique.return_value.all.return_value = [] + session.scalars.return_value = scalars_result + + items, total = ScoreDefinition.list_for_project("demo") + assert items == [] + assert total == 0 diff --git a/tests/unit/mcp/test_tools_common.py b/tests/unit/mcp/test_tools_common.py index 9416cc16..aeaffca9 100644 --- a/tests/unit/mcp/test_tools_common.py +++ b/tests/unit/mcp/test_tools_common.py @@ -4,16 +4,21 @@ import pytest from testgen.common.enums import Disposition, ImpactDimension, IssueLikelihood, PiiRisk, QualityDimension +from testgen.common.models.scores import ScoreCategory from testgen.common.models.test_result import TestResultStatus from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError from testgen.mcp.tools.common import ( + SCORE_CATEGORY_ARG_TO_COLUMN, + SCORE_CHAIN_LEAF_TO_COLUMN, SCORE_FILTER_FIELD_TO_COLUMN, SCORE_GROUP_BY_TO_COLUMN, - SCORE_TYPE_TO_INTERNAL, + ScoreCategoryArg, + ScoreChainLeafField, ScoreFilterField, ScoreGroupBy, ScoreType, format_disposition, + parse_category, parse_disposition, parse_impact_dimension, parse_issue_likelihood_list, @@ -566,26 +571,25 @@ def test_parse_score_filter_field_invalid_lists_valid_values(): @pytest.mark.parametrize( - "label,expected_member,expected_internal", + "label,expected_member", [ - ("Combined", ScoreType.COMBINED, "total"), - ("CDE", ScoreType.CDE, "cde"), + ("Total", ScoreType.TOTAL), + ("CDE", ScoreType.CDE), ], ) -def test_parse_score_type_user_labels(label, expected_member, expected_internal): +def test_parse_score_type_user_labels(label, expected_member): member = parse_score_type(label) assert member is expected_member - assert SCORE_TYPE_TO_INTERNAL[member] == expected_internal -@pytest.mark.parametrize("internal", ["total", "cde", "combined"]) +@pytest.mark.parametrize("internal", ["total", "cde"]) def test_parse_score_type_rejects_internal_or_wrong_case(internal): - """The old internal vocabulary (``total``/``cde`` lowercase) must no longer - be accepted on input.""" + """The internal vocabulary (``total``/``cde`` lowercase) must not be + accepted on input; only the canonical user-facing values are.""" with pytest.raises(MCPUserError, match="Invalid score_type") as exc_info: parse_score_type(internal) msg = str(exc_info.value) - assert "Combined" in msg + assert "Total" in msg assert "CDE" in msg @@ -595,3 +599,81 @@ def test_parse_score_type_invalid_lists_valid_values(): msg = str(exc_info.value) for member in ScoreType: assert member.value in msg + + +# --- parse_category --- + + +@pytest.mark.parametrize( + "display_value,expected", + [ + ("Quality Dimension", ScoreCategory.dq_dimension), + ("Impact Dimension", ScoreCategory.impact_dimension), + ("Table Group", ScoreCategory.table_groups_name), + ("Data Source", ScoreCategory.data_source), + ("Data Location", ScoreCategory.data_location), + ("Source System", ScoreCategory.source_system), + ("Source Process", ScoreCategory.source_process), + ("Business Domain", ScoreCategory.business_domain), + ("Stakeholder Group", ScoreCategory.stakeholder_group), + ("Transform Level", ScoreCategory.transform_level), + ("Data Product", ScoreCategory.data_product), + ], +) +def test_parse_category_display_form_returns_column_form_enum(display_value, expected): + """``parse_category`` accepts display-form labels and emits the column-form ``ScoreCategory``.""" + assert parse_category(display_value) is expected + + +def test_parse_category_translation_dict_covers_all_args(): + """Every ``ScoreCategoryArg`` member has a translation to a valid ``ScoreCategory`` column.""" + for arg in ScoreCategoryArg: + column = SCORE_CATEGORY_ARG_TO_COLUMN[arg] + assert ScoreCategory(column) is ScoreCategory(column) # raises if column isn't a valid enum value + + +@pytest.mark.parametrize( + "internal", + [ + "dq_dimension", + "impact_dimension", + "table_groups_name", + "data_source", + "data_location", + "source_system", + "source_process", + "business_domain", + "stakeholder_group", + "transform_level", + "data_product", + ], +) +def test_parse_category_rejects_column_form_input(internal): + """The old column-form values must not be accepted on input — display-form only.""" + with pytest.raises(MCPUserError, match="Invalid category") as exc_info: + parse_category(internal) + msg = str(exc_info.value) + # Error message must list at least one display-form value to guide the caller. + assert "Quality Dimension" in msg + + +def test_parse_category_invalid_lists_display_form_values(): + """An unrelated bad value lists every display-form value in the error message.""" + with pytest.raises(MCPUserError, match="Valid values:") as exc_info: + parse_category("Made Up") + msg = str(exc_info.value) + for member in ScoreCategoryArg: + assert member.value in msg + + +# --- ScoreChainLeafField --- + + +def test_score_chain_leaf_field_values(): + assert ScoreChainLeafField.TABLE.value == "Table" + assert ScoreChainLeafField.COLUMN.value == "Column" + + +def test_score_chain_leaf_to_column_mapping(): + assert SCORE_CHAIN_LEAF_TO_COLUMN[ScoreChainLeafField.TABLE] == "table_name" + assert SCORE_CHAIN_LEAF_TO_COLUMN[ScoreChainLeafField.COLUMN] == "column_name" diff --git a/tests/unit/mcp/test_tools_quality_scores.py b/tests/unit/mcp/test_tools_quality_scores.py index 10aae95c..62dd31c4 100644 --- a/tests/unit/mcp/test_tools_quality_scores.py +++ b/tests/unit/mcp/test_tools_quality_scores.py @@ -3,8 +3,15 @@ import pytest +from testgen.common.models.scores import ( + ScoreCategory, + ScoreDefinition, + ScoreDefinitionBreakdownItem, + ScoreDefinitionCriteria, +) from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError from testgen.mcp.permissions import ProjectPermissions +from testgen.mcp.tools.quality_scores import _format_criteria_summary pytestmark = pytest.mark.unit @@ -36,7 +43,7 @@ def _score_card( def _patch_perms(allowed=("demo",), memberships=None): """Return a patch context manager that injects a ProjectPermissions with given access.""" - memberships = memberships or {code: "role_a" for code in allowed} + memberships = memberships or dict.fromkeys(allowed, "role_a") return patch( "testgen.mcp.permissions._compute_project_permissions", return_value=ProjectPermissions( @@ -80,120 +87,103 @@ def test_invalid_score_type_rejected(db_session_mock): with _patch_perms(), pytest.raises(MCPUserError, match="Invalid score_type") as exc_info: get_quality_scores(project_code="demo", score_type="garbage") msg = str(exc_info.value) - assert "Combined" in msg + assert "Total" in msg assert "CDE" in msg @pytest.mark.parametrize("internal", ["total", "cde"]) def test_internal_score_type_rejected(internal, db_session_mock): - """``total``/``cde`` were the old internal codes — inputs now use ``Combined``/``CDE``.""" + """``total``/``cde`` were the old internal codes — inputs now use ``Total``/``CDE``.""" from testgen.mcp.tools.quality_scores import get_quality_scores with _patch_perms(), pytest.raises(MCPUserError, match="Invalid score_type"): get_quality_scores(project_code="demo", score_type=internal) -def test_invalid_filter_field_rejected(db_session_mock): +def test_project_not_accessible_rejected(db_session_mock): + """A project the user can't view raises MCPResourceNotAccessible-style error.""" from testgen.mcp.tools.quality_scores import get_quality_scores - with _patch_perms(), pytest.raises(MCPUserError, match="Invalid filter field"): - get_quality_scores( - project_code="demo", - filters=[{"field": "not_a_field", "value": "x"}], - ) - + with _patch_perms(allowed=("only_this",)), pytest.raises(MCPResourceNotAccessible, match="forbidden_proj"): + get_quality_scores(project_code="forbidden_proj") -def test_internal_filter_field_rejected(db_session_mock): - """Old internal column name as filter field is no longer accepted.""" - from testgen.mcp.tools.quality_scores import get_quality_scores - with _patch_perms(), pytest.raises(MCPUserError, match="Invalid filter field"): - get_quality_scores( - project_code="demo", - filters=[{"field": "business_domain", "value": "Finance"}], - ) +# --- Score-type → model-call mapping --- -def test_quality_dimension_rejected_as_filter_field(db_session_mock): - """Quality Dimension is a group_by, not a filter field — must reject with a hint.""" +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_default_overall_shows_both_total_and_cde(mock_definition_cls, db_session_mock): + """score_type omitted → both Total and CDE Score lines are rendered.""" from testgen.mcp.tools.quality_scores import get_quality_scores - with _patch_perms(), pytest.raises(MCPUserError, match="Quality Dimension") as exc_info: - get_quality_scores( - project_code="demo", - filters=[{"field": "Quality Dimension", "value": "Accuracy"}], - ) - assert "group_by" in str(exc_info.value) - - -def test_impact_dimension_rejected_as_filter_field(db_session_mock): - from testgen.mcp.tools.quality_scores import get_quality_scores + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.93, cde_score=0.81) + mock_definition_cls.return_value = mock_definition - with _patch_perms(), pytest.raises(MCPUserError, match="Impact Dimension") as exc_info: - get_quality_scores( + with _patch_perms(): + out = get_quality_scores( project_code="demo", - filters=[{"field": "Impact Dimension", "value": "Workflow"}], + filters=[{"field": "Business Domain", "value": "Finance"}], ) - assert "group_by" in str(exc_info.value) - -def test_filter_value_with_forbidden_chars_rejected(db_session_mock): - """SQL-injection probe — values with single quotes or semicolons must be rejected.""" - from testgen.mcp.tools.quality_scores import get_quality_scores - - with _patch_perms(), pytest.raises(MCPUserError, match="forbidden"): - get_quality_scores( - project_code="demo", - filters=[{"field": "Business Domain", "value": "O';DROP TABLE"}], - ) + assert "Total Score" in out + assert "93" in out + assert "CDE Score" in out + assert "81" in out -def test_filter_value_oversize_rejected(db_session_mock): +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_total_overall_shows_only_total(mock_definition_cls, db_session_mock): + """score_type='Total' renders only the Total Score line.""" from testgen.mcp.tools.quality_scores import get_quality_scores - with _patch_perms(), pytest.raises(MCPUserError, match="too long"): - get_quality_scores( + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.93, cde_score=None) + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( project_code="demo", - filters=[{"field": "Business Domain", "value": "x" * 257}], + score_type="Total", + filters=[{"field": "Business Domain", "value": "Finance"}], ) - -def test_multiple_filter_problems_listed_at_once(db_session_mock): - """When several filter entries are bad, the error lists every offender.""" - from testgen.mcp.tools.quality_scores import get_quality_scores - - bad_filters = [ - {"field": "Quality Dimension", "value": "Accuracy"}, # not a filter field - {"field": "Business Domain", "value": "x';--"}, # bad chars - {"field": "Data Source", "value": ""}, # empty value - ] - with _patch_perms(), pytest.raises(MCPUserError) as exc_info: - get_quality_scores(project_code="demo", filters=bad_filters) - - msg = str(exc_info.value) - assert "Quality Dimension" in msg - assert "Business Domain" in msg - assert "Data Source" in msg + assert "Total Score" in out + assert "93" in out + assert "CDE Score" not in out -def test_project_not_accessible_rejected(db_session_mock): - """A project the user can't view raises MCPResourceNotAccessible-style error.""" +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_cde_overall_shows_only_cde(mock_definition_cls, db_session_mock): + """score_type='CDE' renders only the CDE Score line.""" from testgen.mcp.tools.quality_scores import get_quality_scores - with _patch_perms(allowed=("only_this",)), pytest.raises(MCPResourceNotAccessible, match="forbidden_proj"): - get_quality_scores(project_code="forbidden_proj") + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=None, cde_score=0.81) + mock_definition_cls.return_value = mock_definition + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + score_type="CDE", + filters=[{"field": "Business Domain", "value": "Finance"}], + ) -# --- Score-type → model-call mapping --- + assert "CDE Score" in out + assert "81" in out + assert "Total Score" not in out @patch("testgen.mcp.tools.quality_scores.ScoreDefinition") -def test_default_overall_shows_both_combined_and_cde(mock_definition_cls, db_session_mock): - """score_type omitted → both Combined and CDE Score lines are rendered.""" +def test_default_overall_includes_profiling_and_testing(mock_definition_cls, db_session_mock): + """score_type omitted → overall block surfaces Total, CDE, Profiling, + and Testing — same set the UI's score-card shows when Total is enabled.""" from testgen.mcp.tools.quality_scores import get_quality_scores mock_definition = MagicMock() - mock_definition.as_score_card.return_value = _score_card(score=0.93, cde_score=0.81) + mock_definition.as_score_card.return_value = _score_card( + score=0.93, cde_score=0.81, profiling_score=0.95, testing_score=0.85, + ) mock_definition_cls.return_value = mock_definition with _patch_perms(): @@ -202,40 +192,48 @@ def test_default_overall_shows_both_combined_and_cde(mock_definition_cls, db_ses filters=[{"field": "Business Domain", "value": "Finance"}], ) - assert "Combined Score" in out - assert "93" in out + assert "Total Score" in out assert "CDE Score" in out - assert "81" in out + assert "Profiling Score" in out + assert "Testing Score" in out + assert "95" in out + assert "85" in out @patch("testgen.mcp.tools.quality_scores.ScoreDefinition") -def test_combined_overall_shows_only_combined(mock_definition_cls, db_session_mock): - """score_type='Combined' renders only the Combined Score line.""" +def test_total_overall_includes_profiling_and_testing(mock_definition_cls, db_session_mock): + """score_type='Total' → Total + Profiling + Testing render; CDE omitted.""" from testgen.mcp.tools.quality_scores import get_quality_scores mock_definition = MagicMock() - mock_definition.as_score_card.return_value = _score_card(score=0.93, cde_score=None) + mock_definition.as_score_card.return_value = _score_card( + score=0.93, cde_score=None, profiling_score=0.95, testing_score=0.85, + ) mock_definition_cls.return_value = mock_definition with _patch_perms(): out = get_quality_scores( project_code="demo", - score_type="Combined", + score_type="Total", filters=[{"field": "Business Domain", "value": "Finance"}], ) - assert "Combined Score" in out - assert "93" in out + assert "Total Score" in out + assert "Profiling Score" in out + assert "Testing Score" in out assert "CDE Score" not in out @patch("testgen.mcp.tools.quality_scores.ScoreDefinition") -def test_cde_overall_shows_only_cde(mock_definition_cls, db_session_mock): - """score_type='CDE' renders only the CDE Score line.""" +def test_cde_overall_omits_profiling_and_testing(mock_definition_cls, db_session_mock): + """score_type='CDE' → Profiling/Testing must not appear even if the score + card returns values for them (matches UI's Total-only gating).""" from testgen.mcp.tools.quality_scores import get_quality_scores mock_definition = MagicMock() - mock_definition.as_score_card.return_value = _score_card(score=None, cde_score=0.81) + mock_definition.as_score_card.return_value = _score_card( + score=None, cde_score=0.81, profiling_score=0.95, testing_score=0.85, + ) mock_definition_cls.return_value = mock_definition with _patch_perms(): @@ -246,13 +244,14 @@ def test_cde_overall_shows_only_cde(mock_definition_cls, db_session_mock): ) assert "CDE Score" in out - assert "81" in out - assert "Combined Score" not in out + assert "Total Score" not in out + assert "Profiling Score" not in out + assert "Testing Score" not in out @patch("testgen.mcp.tools.quality_scores.ScoreDefinition") -def test_combined_grouped_uses_breakdown(mock_definition_cls, db_session_mock): - """score_type='Combined' + group_by sources per-category rows from breakdown. +def test_total_grouped_uses_breakdown(mock_definition_cls, db_session_mock): + """score_type='Total' + group_by sources per-category rows from breakdown. Per-category output always includes Impact (matching the Score Explorer UI), so the tool reads from get_score_card_breakdown rather than card.categories. @@ -270,7 +269,7 @@ def test_combined_grouped_uses_breakdown(mock_definition_cls, db_session_mock): with _patch_perms(): out = get_quality_scores( project_code="demo", - score_type="Combined", + score_type="Total", group_by="Business Domain", filters=[{"field": "Data Source", "value": "warehouse"}], include_impact=True, @@ -279,7 +278,7 @@ def test_combined_grouped_uses_breakdown(mock_definition_cls, db_session_mock): mock_definition.get_score_card_breakdown.assert_called_once_with("score", "business_domain") assert "Finance" in out assert "Marketing" in out - assert "Impact on Combined Score" in out + assert "Impact on Total Score" in out assert "Impact on CDE Score" not in out @@ -307,12 +306,12 @@ def test_cde_grouped_uses_breakdown(mock_definition_cls, db_session_mock): mock_definition.get_score_card_breakdown.assert_called_once_with("cde_score", "business_domain") assert "Finance" in out assert "Impact on CDE Score" in out - assert "Impact on Combined Score" not in out + assert "Impact on Total Score" not in out @patch("testgen.mcp.tools.quality_scores.ScoreDefinition") def test_default_grouped_renders_both_score_columns(mock_definition_cls, db_session_mock): - """score_type omitted + group_by → table has Combined + CDE columns and + """score_type omitted + group_by → table has Total + CDE columns and Impact columns for both, populated from two breakdown calls. """ from testgen.mcp.tools.quality_scores import get_quality_scores @@ -348,9 +347,9 @@ def test_default_grouped_renders_both_score_columns(mock_definition_cls, db_sess call_keys = {c.args[0] for c in mock_definition.get_score_card_breakdown.call_args_list} assert call_keys == {"score", "cde_score"} - assert "Combined Score" in out + assert "Total Score" in out assert "CDE Score" in out - assert "Impact on Combined Score" in out + assert "Impact on Total Score" in out assert "Impact on CDE Score" in out assert "Finance" in out assert "Marketing" in out @@ -383,14 +382,14 @@ def test_include_impact_default_false_omits_impact_columns(mock_definition_cls, ) assert "Finance" in out - assert "Combined Score" in out + assert "Total Score" in out assert "CDE Score" in out assert "Impact" not in out @patch("testgen.mcp.tools.quality_scores.ScoreDefinition") -def test_include_impact_false_combined_only_omits_impact_column(mock_definition_cls, db_session_mock): - """Combined-only + default include_impact=False → no impact column.""" +def test_include_impact_false_total_only_omits_impact_column(mock_definition_cls, db_session_mock): + """Total-only + default include_impact=False → no impact column.""" from testgen.mcp.tools.quality_scores import get_quality_scores mock_definition = MagicMock() @@ -403,13 +402,13 @@ def test_include_impact_false_combined_only_omits_impact_column(mock_definition_ with _patch_perms(): out = get_quality_scores( project_code="demo", - score_type="Combined", + score_type="Total", group_by="Business Domain", filters=[{"field": "Data Source", "value": "wh"}], ) assert "Finance" in out - assert "Combined Score" in out + assert "Total Score" in out assert "Impact" not in out @@ -488,8 +487,8 @@ def test_include_issue_ct_overall_calls_get_overall_issue_ct(mock_definition_cls @patch("testgen.mcp.tools.quality_scores.ScoreDefinition") -def test_include_issue_ct_grouped_combined_uses_simple_label(mock_definition_cls, db_session_mock): - """grouped + Combined + include_issue_ct: single 'Issue Count' column header.""" +def test_include_issue_ct_grouped_total_uses_simple_label(mock_definition_cls, db_session_mock): + """grouped + Total + include_issue_ct: single 'Issue Count' column header.""" from testgen.mcp.tools.quality_scores import get_quality_scores mock_definition = MagicMock() @@ -502,7 +501,7 @@ def test_include_issue_ct_grouped_combined_uses_simple_label(mock_definition_cls with _patch_perms(): out = get_quality_scores( project_code="demo", - score_type="Combined", + score_type="Total", group_by="Business Domain", include_issue_ct=True, filters=[{"field": "Data Source", "value": "wh"}], @@ -512,7 +511,7 @@ def test_include_issue_ct_grouped_combined_uses_simple_label(mock_definition_cls assert "Finance" in out assert "7" in out assert "Issue Count" in out - assert "Issue Count (Combined)" not in out + assert "Issue Count (Total)" not in out assert "Issue Count (CDE)" not in out @@ -541,13 +540,13 @@ def test_include_issue_ct_grouped_cde_uses_simple_label(mock_definition_cls, db_ assert "Finance" in out assert "3" in out assert "Issue Count" in out - assert "Issue Count (Combined)" not in out + assert "Issue Count (Total)" not in out assert "Issue Count (CDE)" not in out @patch("testgen.mcp.tools.quality_scores.ScoreDefinition") def test_include_issue_ct_grouped_default_uses_parenthetical_labels(mock_definition_cls, db_session_mock): - """grouped + score_type unset + include_issue_ct: separate Combined / CDE + """grouped + score_type unset + include_issue_ct: separate Total / CDE issue-count columns, and both Impact columns.""" from testgen.mcp.tools.quality_scores import get_quality_scores @@ -572,12 +571,12 @@ def test_include_issue_ct_grouped_default_uses_parenthetical_labels(mock_definit ) assert mock_definition.get_score_card_breakdown.call_count == 2 - assert "Issue Count (Combined)" in out + assert "Issue Count (Total)" in out assert "Issue Count (CDE)" in out - assert "Impact on Combined Score" in out + assert "Impact on Total Score" in out assert "Impact on CDE Score" in out # Both per-category issue counts must appear, not just one - assert "7" in out # combined count + assert "7" in out # total count assert "3" in out # cde count @@ -712,7 +711,7 @@ def test_grouped_row_cap_truncates_and_footers(mock_definition_cls, db_session_m with _patch_perms(): out = get_quality_scores( project_code="demo", - score_type="Combined", + score_type="Total", group_by="Business Domain", filters=[{"field": "Data Source", "value": "wh"}], ) @@ -721,6 +720,57 @@ def test_grouped_row_cap_truncates_and_footers(mock_definition_cls, db_session_m assert str(_ROW_CAP + 10) in out +# --- Empty-breakdown messaging differs based on whether filters were supplied --- + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_grouped_empty_breakdown_with_filters_renders_filter_matched(mock_definition_cls, db_session_mock): + """User-supplied filter that returns no breakdown rows surfaces 'Filter matched no data.'""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9) + mock_definition.get_score_card_breakdown.return_value = [] + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + group_by="Business Domain", + filters=[{"field": "Business Domain", "value": "Finance"}], + ) + + assert "Filter matched no data" in out + assert "No category data" not in out + + +@patch("testgen.mcp.tools.quality_scores.TableGroup") +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_grouped_empty_breakdown_without_filters_renders_no_category_data( + mock_definition_cls, mock_tg_cls, db_session_mock, +): + """Unfiltered project with no breakdown rows keeps the generic 'No category data.' message.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + tg = MagicMock() + tg.table_groups_name = "orders" + mock_tg_cls.select_minimal_where.return_value = [tg] + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9) + mock_definition.get_score_card_breakdown.return_value = [] + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + out = get_quality_scores( + project_code="demo", + group_by="Business Domain", + ) + + assert "No category data" in out + assert "Filter matched no data" not in out + + # --- Transient definition is never persisted --- @@ -740,3 +790,1989 @@ def test_transient_definition_never_persisted(mock_definition_cls, db_session_mo ) mock_definition.save.assert_not_called() + + +# ============================================================ +# Scorecard tools — merged in from test_tools_scorecards.py +# ============================================================ + +def _criteria(filters: list[dict], group_by_field: bool = True) -> ScoreDefinitionCriteria: + return ScoreDefinitionCriteria.from_filters(filters, group_by_field=group_by_field) + + + + +def _fake_definition( + name: str, + *, + project_code: str = "demo", + total: bool = True, + cde: bool = False, + category: ScoreCategory | None = None, + filters: list[dict] | None = None, + group_by_field: bool = True, + score: float | None = 0.95, + cde_value: float | None = 0.90, +) -> ScoreDefinition: + sd = ScoreDefinition() + sd.id = uuid4() + sd.project_code = project_code + sd.name = name + sd.total_score = total + sd.cde_score = cde + sd.category = category + sd.criteria = ScoreDefinitionCriteria.from_filters( + filters or [{"field": "table_groups_name", "value": "tg1"}], + group_by_field=group_by_field, + ) + sd._fake_card = {"score": score, "cde_score": cde_value} + return sd + + +@pytest.fixture +def patch_card(monkeypatch): + """Route as_cached_score_card to the stub stored on each fake definition.""" + def _cached(self, include_definition: bool = False): + return self._fake_card + monkeypatch.setattr(ScoreDefinition, "as_cached_score_card", _cached) + + +def _patch_list(items, total): + return patch.object(ScoreDefinition, "list_for_project", return_value=(items, total)) + + +# --- _format_criteria_summary --- + + +def test_format_criteria_summary_none(): + assert _format_criteria_summary(None) == "(no filters)" + + +def test_format_criteria_summary_empty(): + criteria = ScoreDefinitionCriteria(operand="AND", filters=[], group_by_field=True) + assert _format_criteria_summary(criteria) == "(no filters)" + + +def test_format_criteria_summary_single_filter_uses_display_label(): + criteria = _criteria([{"field": "table_groups_name", "value": "sales"}]) + assert _format_criteria_summary(criteria) == "Table Group = sales" + + +def test_format_criteria_summary_or_within_field(): + """group_by_field=True with multiple roots on the same field renders as `in (...)`.""" + criteria = _criteria([ + {"field": "table_groups_name", "value": "sales"}, + {"field": "table_groups_name", "value": "marketing"}, + ]) + assert _format_criteria_summary(criteria) == "Table Group in (sales, marketing)" + + +def test_format_criteria_summary_and_across_fields(): + criteria = _criteria([ + {"field": "table_groups_name", "value": "sales"}, + {"field": "business_domain", "value": "Finance"}, + ]) + # Ordering is alphabetical by display label for stable output. + assert _format_criteria_summary(criteria) == "Business Domain = Finance AND Table Group = sales" + + +def test_format_criteria_summary_chained_next_filter(): + """A root filter with `others` becomes a next_filter AND-chain inside the root.""" + criteria = ScoreDefinitionCriteria.from_filters( + [{ + "field": "table_groups_name", + "value": "sales", + "others": [{"field": "business_domain", "value": "Finance"}], + }], + group_by_field=True, + ) + summary = _format_criteria_summary(criteria) + assert "Table Group = sales" in summary + assert "Business Domain = Finance" in summary + assert " AND " in summary + + +def test_format_criteria_summary_unknown_field_falls_back_to_raw_column(): + criteria = _criteria([{"field": "made_up_column", "value": "x"}]) + assert _format_criteria_summary(criteria) == "made_up_column = x" + + +def test_format_criteria_summary_mode_2_chained_uses_table_label(): + """A chain into table_name renders the user-facing "Table" label, not the column name.""" + criteria = ScoreDefinitionCriteria.from_filters( + [{ + "field": "table_groups_name", + "value": "redbox", + "others": [{"field": "table_name", "value": "accounts"}], + }], + group_by_field=False, + ) + summary = _format_criteria_summary(criteria) + assert "Table Group = redbox" in summary + assert "Table = accounts" in summary + assert "table_name" not in summary + + +def test_format_criteria_summary_mode_2_chained_uses_column_label(): + criteria = ScoreDefinitionCriteria.from_filters( + [{ + "field": "table_groups_name", + "value": "redbox", + "others": [ + {"field": "table_name", "value": "accounts"}, + {"field": "column_name", "value": "id"}, + ], + }], + group_by_field=False, + ) + summary = _format_criteria_summary(criteria) + assert "Column = id" in summary + assert "column_name" not in summary + + +def test_format_criteria_summary_mode_2_sibling_chains_collapse_to_in(): + """Chains sharing the same root (table_groups_name=X) collapse to `Table in (...)`.""" + criteria = ScoreDefinitionCriteria.from_filters( + [ + {"field": "table_groups_name", "value": "redbox", + "others": [{"field": "table_name", "value": "a"}]}, + {"field": "table_groups_name", "value": "redbox", + "others": [{"field": "table_name", "value": "b"}]}, + {"field": "table_groups_name", "value": "redbox", + "others": [{"field": "table_name", "value": "c"}]}, + ], + group_by_field=False, + ) + summary = _format_criteria_summary(criteria) + assert summary == "Table Group = redbox AND Table in (a, b, c)" + + +def test_format_criteria_summary_mode_2_different_roots_or_joined(): + """Chains with different table_groups_name roots are OR-joined (not AND-joined).""" + criteria = ScoreDefinitionCriteria.from_filters( + [ + {"field": "table_groups_name", "value": "redbox", + "others": [{"field": "table_name", "value": "a"}]}, + {"field": "table_groups_name", "value": "sales", + "others": [{"field": "table_name", "value": "b"}]}, + ], + group_by_field=False, + ) + summary = _format_criteria_summary(criteria) + assert " OR " in summary + assert " AND " not in summary.replace(" AND Table = ", "") # AND only inside a chain + assert "redbox" in summary + assert "sales" in summary + + +# --- list_scorecards tool --- + + +def test_list_scorecards_requires_view_access(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import list_scorecards + + with _patch_perms(allowed=("only_this",)), pytest.raises( + MCPResourceNotAccessible, match="forbidden_proj" + ): + list_scorecards("forbidden_proj") + + +def test_list_scorecards_empty_renders_friendly_message(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import list_scorecards + + with _patch_perms(), _patch_list([], 0): + out = list_scorecards("demo") + assert "Scorecards in Project `demo`" in out + assert "_No scorecards configured._" in out + + +def test_list_scorecards_renders_total_and_cde(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import list_scorecards + + items = [ + _fake_definition( + "Sales Quality", + total=True, + cde=True, + category=ScoreCategory.dq_dimension, + filters=[{"field": "table_groups_name", "value": "sales"}], + score=0.95, + cde_value=0.90, + ), + ] + with _patch_perms(), _patch_list(items, 1): + out = list_scorecards("demo") + assert "Sales Quality" in out + assert "Total Score" in out + assert "CDE Score" in out + assert "Quality Dimension" in out # display label for dq_dimension + assert "Table Group = sales" in out + assert "0.95" in out or "95" in out + assert "0.90" in out or "90" in out + + +def test_list_scorecards_hides_cde_when_disabled(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import list_scorecards + + items = [_fake_definition("Only Total", total=True, cde=False, cde_value=None)] + with _patch_perms(), _patch_list(items, 1): + out = list_scorecards("demo") + assert "Total Score" in out + assert "CDE Score" not in out + + +def test_list_scorecards_hides_total_when_disabled(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import list_scorecards + + items = [_fake_definition("CDE Only", total=False, cde=True, score=None, cde_value=0.50)] + with _patch_perms(), _patch_list(items, 1): + out = list_scorecards("demo") + assert "CDE Score" in out + assert "Total Score" not in out + + +def test_list_scorecards_includes_profiling_and_testing_when_total_enabled(db_session_mock, patch_card): + """When total_score is enabled, the per-scorecard block surfaces Profiling + Score and Testing Score — matching the UI's score-card and get_scorecard.""" + from testgen.mcp.tools.quality_scores import list_scorecards + + sd = _fake_definition( + "Full Card", + total=True, + cde=True, + score=0.925, + cde_value=0.880, + ) + sd._fake_card.update({"profiling_score": 0.950, "testing_score": 0.900}) + with _patch_perms(), _patch_list([sd], 1): + out = list_scorecards("demo") + assert "Profiling Score" in out + assert "Testing Score" in out + # friendly_score scales by 100 and rounds to 1 decimal. + assert "95" in out + assert "90" in out + + +def test_list_scorecards_omits_profiling_and_testing_for_cde_only_scorecard(db_session_mock, patch_card): + """When total_score is disabled, Profiling/Testing must not appear even + though as_cached_score_card may return values for them.""" + from testgen.mcp.tools.quality_scores import list_scorecards + + sd = _fake_definition("CDE Only", total=False, cde=True, score=None, cde_value=0.50) + sd._fake_card.update({"profiling_score": 0.7, "testing_score": 0.8}) + with _patch_perms(), _patch_list([sd], 1): + out = list_scorecards("demo") + assert "CDE Score" in out + assert "Profiling Score" not in out + assert "Testing Score" not in out + + +def test_list_scorecards_omits_breakdown_when_no_category(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import list_scorecards + + items = [_fake_definition("Plain", category=None)] + with _patch_perms(), _patch_list(items, 1): + out = list_scorecards("demo") + assert "Category" not in out + + +def test_list_scorecards_emits_pagination_info_and_footer(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import list_scorecards + + items = [_fake_definition(f"Card {i}") for i in range(3)] + with _patch_perms(), _patch_list(items, 25): + out = list_scorecards("demo", page=1, limit=3) + # format_page_info emits an en-dash (\u2013) between start and end. + assert "Showing 1\u20133 of 25" in out + assert "Use `page=2` for more" in out + + +def test_list_scorecards_empty_page_past_end(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import list_scorecards + + with _patch_perms(), _patch_list([], 3): + out = list_scorecards("demo", page=5, limit=10) + # No-scorecards-on-page message references current page + total + assert "page 5" in out + assert "total: 3" in out + + +@pytest.mark.parametrize("page,limit", [(0, 10), (1, 0), (1, 101)]) +def test_list_scorecards_rejects_invalid_pagination(db_session_mock, patch_card, page, limit): + from testgen.mcp.tools.quality_scores import list_scorecards + + with _patch_perms(), pytest.raises(MCPUserError): + list_scorecards("demo", page=page, limit=limit) + + +def test_list_scorecards_renders_filter_chain(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import list_scorecards + + items = [_fake_definition( + "Multi-filter", + filters=[ + {"field": "table_groups_name", "value": "sales"}, + {"field": "business_domain", "value": "Finance"}, + ], + )] + with _patch_perms(), _patch_list(items, 1): + out = list_scorecards("demo") + assert "Business Domain = Finance" in out + assert "Table Group = sales" in out + assert " AND " in out + + +# --- get_scorecard tool --- + + +def _fake_breakdown_item( + *, + category: str, + score_type: str, + field_values: dict, + impact: float = 0.5, + score: float = 0.85, + issue_ct: int = 3, +): + """Build a fake `ScoreDefinitionBreakdownItem`-like object exposing ``.to_dict()``. + + Matches the shape produced by the real ``to_dict`` — category-specific fields + plus ``impact``, ``score``, ``issue_ct``. + """ + item = MagicMock(spec=ScoreDefinitionBreakdownItem) + item.category = category + item.score_type = score_type + item.to_dict = MagicMock(return_value={ + **field_values, + "impact": impact, + "score": score, + "issue_ct": issue_ct, + }) + return item + + +def _patch_get(definition): + return patch.object(ScoreDefinition, "get", return_value=definition) + + +def _patch_breakdown(items): + return patch.object(ScoreDefinitionBreakdownItem, "filter", return_value=items) + + +def _patch_breakdown_by_score_type(total, cde): + """Return different breakdown rows depending on the requested ``score_type``.""" + def _filter(*, definition_id, category, score_type): + return total if score_type == "score" else cde + return patch.object(ScoreDefinitionBreakdownItem, "filter", side_effect=_filter) + + +def test_get_scorecard_rejects_invalid_uuid(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import get_scorecard + + with _patch_perms(), pytest.raises(MCPUserError, match="not a valid UUID"): + get_scorecard("not-a-uuid") + + +def test_get_scorecard_unknown_id_returns_not_accessible(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import get_scorecard + + missing_id = str(uuid4()) + with _patch_perms(), _patch_get(None), pytest.raises( + MCPResourceNotAccessible, match=missing_id + ): + get_scorecard(missing_id) + + +def test_get_scorecard_forbidden_project_returns_not_accessible(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import get_scorecard + + sd = _fake_definition("Other-project card", project_code="forbidden_proj") + with _patch_perms(allowed=("demo",)), _patch_get(sd), pytest.raises( + MCPResourceNotAccessible, match=str(sd.id) + ): + get_scorecard(str(sd.id)) + + +def test_get_scorecard_renders_overall_scores(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import get_scorecard + + sd = _fake_definition( + "Sales Quality", + total=True, + cde=True, + category=ScoreCategory.dq_dimension, + score=0.95, + cde_value=0.90, + ) + sd._fake_card.update({"profiling_score": 0.88, "testing_score": 0.91}) + with _patch_perms(), _patch_get(sd), _patch_breakdown([]): + out = get_scorecard(str(sd.id)) + assert "Sales Quality" in out + assert "Total Score" in out + assert "CDE Score" in out + assert "Profiling Score" in out + assert "Testing Score" in out + # Filter summary is preserved from list_scorecards behavior. + assert "Table Group = tg1" in out + + +def test_get_scorecard_hides_total_when_disabled(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import get_scorecard + + sd = _fake_definition( + "CDE-Only Card", + total=False, + cde=True, + category=None, + score=None, + cde_value=0.5, + ) + with _patch_perms(), _patch_get(sd), _patch_breakdown([]): + out = get_scorecard(str(sd.id)) + assert "CDE Score" in out + assert "Total Score" not in out + # Profiling/Testing are components of the Total score — should be hidden too. + assert "Profiling Score" not in out + assert "Testing Score" not in out + + +def test_get_scorecard_hides_cde_when_disabled(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import get_scorecard + + sd = _fake_definition( + "Total-Only Card", + total=True, + cde=False, + category=None, + cde_value=None, + ) + sd._fake_card.update({"profiling_score": 0.7, "testing_score": 0.8}) + with _patch_perms(), _patch_get(sd), _patch_breakdown([]): + out = get_scorecard(str(sd.id)) + assert "Total Score" in out + assert "Profiling Score" in out + assert "Testing Score" in out + assert "CDE Score" not in out + + +def test_get_scorecard_omits_breakdown_when_no_category(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import get_scorecard + + sd = _fake_definition("Plain", category=None) + sd._fake_card.update({"profiling_score": 0.7, "testing_score": 0.8}) + with _patch_perms(), _patch_get(sd), _patch_breakdown([]): + out = get_scorecard(str(sd.id)) + assert "Category" not in out + + +def test_get_scorecard_renders_breakdown_wide_table(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import get_scorecard + + sd = _fake_definition( + "Wide breakdown", + total=True, + cde=True, + category=ScoreCategory.dq_dimension, + ) + sd._fake_card.update({"profiling_score": 0.7, "testing_score": 0.8}) + + total_items = [ + _fake_breakdown_item( + category="dq_dimension", + score_type="score", + field_values={"dq_dimension": "Accuracy"}, + impact=0.4, + score=0.6, + issue_ct=10, + ), + ] + cde_items = [ + _fake_breakdown_item( + category="dq_dimension", + score_type="cde_score", + field_values={"dq_dimension": "Accuracy"}, + impact=0.3, + score=0.7, + issue_ct=5, + ), + ] + + with ( + _patch_perms(), + _patch_get(sd), + _patch_breakdown_by_score_type(total_items, cde_items), + ): + out = get_scorecard(str(sd.id)) + assert "Breakdown by Quality Dimension" in out + assert "Accuracy" in out + # Both score types in headers — parenthetical disambiguates which column is which. + assert "Issue Count (Total)" in out + assert "Issue Count (CDE)" in out + assert "Impact on Total Score" in out + assert "Impact on CDE Score" in out + + +def test_get_scorecard_breakdown_single_score_type(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import get_scorecard + + sd = _fake_definition( + "Single-type breakdown", + total=True, + cde=False, + category=ScoreCategory.business_domain, + cde_value=None, + ) + sd._fake_card.update({"profiling_score": 0.7, "testing_score": 0.8}) + + items = [ + _fake_breakdown_item( + category="business_domain", + score_type="score", + field_values={"business_domain": "Finance"}, + ), + ] + with _patch_perms(), _patch_get(sd), _patch_breakdown(items): + out = get_scorecard(str(sd.id)) + assert "Breakdown by Business Domain" in out + assert "Finance" in out + # When only one type is enabled, headers drop the parenthetical (mirrors get_quality_scores). + assert "Issue Count (Total)" not in out + assert "Issue Count (CDE)" not in out + + +def test_get_scorecard_breakdown_caps_at_100(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import get_scorecard + + sd = _fake_definition( + "Many rows", + total=True, + cde=False, + category=ScoreCategory.business_domain, + cde_value=None, + ) + sd._fake_card.update({"profiling_score": 0.7, "testing_score": 0.8}) + + items = [ + _fake_breakdown_item( + category="business_domain", + score_type="score", + field_values={"business_domain": f"Domain {i}"}, + impact=0.5 - 0.001 * i, + ) + for i in range(101) + ] + with _patch_perms(), _patch_get(sd), _patch_breakdown(items): + out = get_scorecard(str(sd.id)) + assert "Showing top 100 of 101" in out + + +def test_get_scorecard_breakdown_empty(db_session_mock, patch_card): + from testgen.mcp.tools.quality_scores import get_scorecard + + sd = _fake_definition( + "No data", + total=True, + cde=False, + category=ScoreCategory.dq_dimension, + cde_value=None, + ) + sd._fake_card.update({"profiling_score": 0.7, "testing_score": 0.8}) + with _patch_perms(), _patch_get(sd), _patch_breakdown([]): + out = get_scorecard(str(sd.id)) + assert "Breakdown by Quality Dimension" in out + assert "_No breakdown data._" in out + + +# --- delete_scorecard tool --- + + +def test_delete_scorecard_rejects_invalid_uuid(db_session_mock): + from testgen.mcp.tools.quality_scores import delete_scorecard + + with _patch_perms(), pytest.raises(MCPUserError, match="not a valid UUID"): + delete_scorecard("not-a-uuid") + + +def test_delete_scorecard_unknown_id_returns_not_accessible(db_session_mock): + from testgen.mcp.tools.quality_scores import delete_scorecard + + missing_id = str(uuid4()) + with _patch_perms(), _patch_get(None), pytest.raises( + MCPResourceNotAccessible, match=missing_id + ): + delete_scorecard(missing_id) + + +def test_delete_scorecard_forbidden_project_does_not_call_delete(db_session_mock): + from testgen.mcp.tools.quality_scores import delete_scorecard + + sd = _fake_definition("Other-project card", project_code="forbidden_proj") + with ( + _patch_perms(allowed=("demo",)), + _patch_get(sd), + patch.object(ScoreDefinition, "delete") as mock_delete, + pytest.raises(MCPResourceNotAccessible, match=str(sd.id)), + ): + delete_scorecard(str(sd.id)) + assert mock_delete.called is False + + +def test_delete_scorecard_calls_model_delete(db_session_mock): + from testgen.mcp.tools.quality_scores import delete_scorecard + + sd = _fake_definition("Sales Quality") + with ( + _patch_perms(), + _patch_get(sd), + patch.object(ScoreDefinition, "delete") as mock_delete, + ): + delete_scorecard(str(sd.id)) + mock_delete.assert_called_once() + + +def test_delete_scorecard_returns_confirmation_with_name_id_project(db_session_mock): + from testgen.mcp.tools.quality_scores import delete_scorecard + + sd = _fake_definition("Sales Quality", project_code="demo") + with _patch_perms(), _patch_get(sd), patch.object(ScoreDefinition, "delete"): + out = delete_scorecard(str(sd.id)) + assert "Sales Quality" in out + assert str(sd.id) in out + assert "demo" in out + + +# --- update_scorecard tool --- + + +def _patch_orchestrator(): + """Stub the persist+refresh orchestrator so unit tests don't hit the DB.""" + return patch("testgen.mcp.tools.quality_scores.save_and_refresh_score_definition") + + +def test_update_scorecard_rejects_invalid_uuid(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + with _patch_perms(), pytest.raises(MCPUserError, match="not a valid UUID"): + update_scorecard("not-a-uuid", name="x") + + +def test_update_scorecard_unknown_id_returns_not_accessible(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + missing_id = str(uuid4()) + with _patch_perms(), _patch_get(None), pytest.raises( + MCPResourceNotAccessible, match=missing_id + ): + update_scorecard(missing_id, name="x") + + +def test_update_scorecard_forbidden_project_does_not_call_save(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Other-project card", project_code="forbidden_proj") + with ( + _patch_perms(allowed=("demo",)), + _patch_get(sd), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPResourceNotAccessible, match=str(sd.id)), + ): + update_scorecard(str(sd.id), name="x") + mock_orch.assert_not_called() + + +def test_update_scorecard_no_fields_supplied_rejected(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality") + with ( + _patch_perms(), + _patch_get(sd), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="No fields supplied"), + ): + update_scorecard(str(sd.id)) + mock_orch.assert_not_called() + + +def test_update_scorecard_empty_name_rejected(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality") + with ( + _patch_perms(), + _patch_get(sd), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="name"), + ): + update_scorecard(str(sd.id), name="") + mock_orch.assert_not_called() + assert sd.name == "Sales Quality" + + +def test_update_scorecard_unknown_category_rejected(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality") + with ( + _patch_perms(), + _patch_get(sd), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="category"), + ): + update_scorecard(str(sd.id), category="not_a_category") + mock_orch.assert_not_called() + + +def test_update_scorecard_filter_without_field_rejected(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality") + with ( + _patch_perms(), + _patch_get(sd), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="field"), + ): + update_scorecard(str(sd.id), filters=[{"value": "x"}]) + mock_orch.assert_not_called() + + +def test_update_scorecard_empty_filters_list_rejected(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality") + with ( + _patch_perms(), + _patch_get(sd), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="filter"), + ): + update_scorecard(str(sd.id), filters=[]) + mock_orch.assert_not_called() + + +def test_update_scorecard_changes_name(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality") + with _patch_perms(), _patch_get(sd), _patch_orchestrator() as mock_orch: + update_scorecard(str(sd.id), name="Renamed Card") + assert sd.name == "Renamed Card" + mock_orch.assert_called_once() + + +def test_update_scorecard_toggles_show_total_score(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality", total=True) + with _patch_perms(), _patch_get(sd), _patch_orchestrator(): + update_scorecard(str(sd.id), show_total_score=False) + assert sd.total_score is False + + +def test_update_scorecard_toggles_show_cde_score(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality", cde=False) + with _patch_perms(), _patch_get(sd), _patch_orchestrator(): + update_scorecard(str(sd.id), show_cde_score=True) + assert sd.cde_score is True + + +def test_update_scorecard_sets_category(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality", category=None) + with _patch_perms(), _patch_get(sd), _patch_orchestrator(): + update_scorecard(str(sd.id), category="Quality Dimension") + assert sd.category == ScoreCategory.dq_dimension + + +def test_update_scorecard_clears_category(db_session_mock): + """Passing an empty ``category`` clears it — distinct from ``None`` (no change).""" + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality", category=ScoreCategory.dq_dimension) + with _patch_perms(), _patch_get(sd), _patch_orchestrator(): + update_scorecard(str(sd.id), category="") + assert sd.category is None + + +def test_update_scorecard_replaces_filters(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition( + "Sales Quality", + filters=[{"field": "table_groups_name", "value": "tg1"}], + ) + with _patch_perms(), _patch_get(sd), _patch_orchestrator(): + update_scorecard( + str(sd.id), + filters=[{"field": "Business Domain", "value": "Finance"}], + ) + new_filters = list(sd.criteria) + assert len(new_filters) == 1 + assert new_filters[0]["field"] == "business_domain" + assert new_filters[0]["value"] == "Finance" + + +def test_update_scorecard_flat_filters_derive_group_by_field_true(db_session_mock): + """Mode 1 shape (flat category filters) → group_by_field=True, regardless of prior state.""" + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition( + "Sales Quality", + filters=[{ + "field": "table_groups_name", + "value": "sales", + "others": [{"field": "table_name", "value": "orders"}], + }], + group_by_field=False, + ) + with _patch_perms(), _patch_get(sd), _patch_orchestrator(): + update_scorecard( + str(sd.id), + filters=[{"field": "Business Domain", "value": "Finance"}], + ) + assert sd.criteria.group_by_field is True + + +def test_update_scorecard_chained_filters_derive_group_by_field_false(db_session_mock): + """Mode 2 shape (any chained filter) → group_by_field=False, regardless of prior state.""" + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition( + "Sales Quality", + filters=[{"field": "table_groups_name", "value": "tg1"}], + group_by_field=True, + ) + with _patch_perms(), _patch_get(sd), _patch_orchestrator(): + update_scorecard( + str(sd.id), + filters=[ + { + "field": "Table Group", + "value": "sales", + "others": [{"field": "Table", "value": "orders"}], + }, + { + "field": "Table Group", + "value": "sales", + "others": [{"field": "Table", "value": "customers"}], + }, + ], + ) + assert sd.criteria.group_by_field is False + + +def test_update_scorecard_mode_1_filter_with_non_category_field_rejected(db_session_mock): + """Flat filter using "Table" (chain-leaf field) must be rejected at the flat level.""" + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality") + with ( + _patch_perms(), + _patch_get(sd), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="Table"), + ): + update_scorecard( + str(sd.id), + filters=[{"field": "Table", "value": "orders"}], + ) + mock_orch.assert_not_called() + + +def test_update_scorecard_mode_2_chain_must_root_at_table_group(db_session_mock): + """Chained filters must start at "Table Group" (matches UI column-selector shape).""" + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality") + with ( + _patch_perms(), + _patch_get(sd), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="Table Group"), + ): + update_scorecard( + str(sd.id), + filters=[{ + "field": "Data Source", + "value": "S", + "others": [{"field": "Table", "value": "x"}], + }], + ) + mock_orch.assert_not_called() + + +def test_update_scorecard_mode_2_chain_must_chain_into_table_or_column(db_session_mock): + """Chain leaves must be "Table" or "Column" — not category fields.""" + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality") + with ( + _patch_perms(), + _patch_get(sd), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="Business Domain"), + ): + update_scorecard( + str(sd.id), + filters=[{ + "field": "Table Group", + "value": "sales", + "others": [{"field": "Business Domain", "value": "Finance"}], + }], + ) + mock_orch.assert_not_called() + + +def test_update_scorecard_mode_2_chain_table_then_column_accepted(db_session_mock): + """A full chain "Table Group" → "Table" → "Column" is valid.""" + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality") + with _patch_perms(), _patch_get(sd), _patch_orchestrator(): + update_scorecard( + str(sd.id), + filters=[{ + "field": "Table Group", + "value": "sales", + "others": [ + {"field": "Table", "value": "orders"}, + {"field": "Column", "value": "id"}, + ], + }], + ) + assert sd.criteria.group_by_field is False + roots = list(sd.criteria) + assert roots[0]["others"][0]["field"] == "table_name" + assert roots[0]["others"][1]["field"] == "column_name" + + +def test_update_scorecard_diff_uses_display_labels(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality", total=True, category=None) + with _patch_perms(), _patch_get(sd), _patch_orchestrator(): + out = update_scorecard( + str(sd.id), + show_total_score=False, + category="Quality Dimension", + ) + assert "Total Score" in out + assert "Category" in out + assert "Quality Dimension" in out + # Internal names must not leak. + assert "total_score" not in out + assert "dq_dimension" not in out + + +def test_update_scorecard_diff_omits_unchanged_fields(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality", total=True, cde=False, category=None) + with _patch_perms(), _patch_get(sd), _patch_orchestrator(): + out = update_scorecard(str(sd.id), name="Renamed") + assert "Name" in out + assert "Total Score" not in out + assert "CDE Score" not in out + assert "Category" not in out + assert "Filters" not in out + + +def test_update_scorecard_response_includes_id_and_project(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality", project_code="demo") + with _patch_perms(), _patch_get(sd), _patch_orchestrator(): + out = update_scorecard(str(sd.id), name="Renamed") + assert str(sd.id) in out + assert "demo" in out + + +def test_update_scorecard_calls_save_and_refresh_with_is_new_false(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality") + with _patch_perms(), _patch_get(sd), _patch_orchestrator() as mock_orch: + update_scorecard(str(sd.id), name="Renamed") + mock_orch.assert_called_once() + args, kwargs = mock_orch.call_args + assert args[0] is sd + assert kwargs == {"is_new": False} + + +def test_update_scorecard_does_not_call_orchestrator_on_filter_validation_failure(db_session_mock): + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality") + with ( + _patch_perms(), + _patch_get(sd), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="field"), + ): + update_scorecard( + str(sd.id), + name="Renamed", + filters=[{"value": "x"}], + ) + mock_orch.assert_not_called() + # Name must not be mutated when a later validation step rejects the payload. + assert sd.name == "Sales Quality" + + +# --- create_scorecard --- + + +_VALID_FILTER = [{"field": "Table Group", "value": "tg1"}] + + +def test_create_scorecard_unknown_project_returns_not_accessible(db_session_mock): + from testgen.mcp.tools.quality_scores import create_scorecard + + with ( + _patch_perms(allowed=("demo",)), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPResourceNotAccessible, match="forbidden_proj"), + ): + create_scorecard("forbidden_proj", "My Card", filters=_VALID_FILTER) + mock_orch.assert_not_called() + + +def test_create_scorecard_rejects_blank_name(db_session_mock): + from testgen.mcp.tools.quality_scores import create_scorecard + + with ( + _patch_perms(), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="name"), + ): + create_scorecard("demo", " ", filters=_VALID_FILTER) + mock_orch.assert_not_called() + + +def test_create_scorecard_requires_filters(db_session_mock): + from testgen.mcp.tools.quality_scores import create_scorecard + + with ( + _patch_perms(), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="filter"), + ): + create_scorecard("demo", "My Card", filters=[]) + mock_orch.assert_not_called() + + +def test_create_scorecard_rejects_invalid_filter_field(db_session_mock): + """dq_dimension is a group_by field, not a flat scorecard filter field.""" + from testgen.mcp.tools.quality_scores import create_scorecard + + with ( + _patch_perms(), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="dq_dimension"), + ): + create_scorecard( + "demo", + "My Card", + filters=[{"field": "dq_dimension", "value": "Validity"}], + ) + mock_orch.assert_not_called() + + +def test_create_scorecard_rejects_filter_value_with_forbidden_chars(db_session_mock): + """Persisted scorecard filters must reject SQL-injection chars — values flow + into raw SQL via ``ScoreDefinitionCriteria.get_as_sql``.""" + from testgen.mcp.tools.quality_scores import create_scorecard + + with ( + _patch_perms(), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="forbidden"), + ): + create_scorecard( + "demo", + "My Card", + filters=[{"field": "Table Group", "value": "tg1' OR '1'='1"}], + ) + mock_orch.assert_not_called() + + +def test_create_scorecard_rejects_filter_value_too_long(db_session_mock): + """Persisted scorecard filter values must respect ``_VALUE_MAX_LEN``.""" + from testgen.mcp.tools.quality_scores import create_scorecard + + with ( + _patch_perms(), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="too long"), + ): + create_scorecard( + "demo", + "My Card", + filters=[{"field": "Table Group", "value": "x" * 300}], + ) + mock_orch.assert_not_called() + + +def test_create_scorecard_rejects_chain_leaf_value_with_forbidden_chars(db_session_mock): + """Chain-leaf values (``others[].value``) also flow into raw SQL — same check.""" + from testgen.mcp.tools.quality_scores import create_scorecard + + with ( + _patch_perms(), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="forbidden"), + ): + create_scorecard( + "demo", + "My Card", + filters=[{ + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Table", "value": "t'; DROP TABLE--"}], + }], + ) + mock_orch.assert_not_called() + + +def test_update_scorecard_rejects_filter_value_with_forbidden_chars(db_session_mock): + """Update path mirrors create — persisted filter values must be safe.""" + from testgen.mcp.tools.quality_scores import update_scorecard + + sd = _fake_definition("Sales Quality") + with ( + _patch_perms(), + _patch_get(sd), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="forbidden"), + ): + update_scorecard( + str(sd.id), + filters=[{"field": "Table Group", "value": 'tg1"'}], + ) + mock_orch.assert_not_called() + + +def test_create_scorecard_rejects_invalid_category(db_session_mock): + from testgen.mcp.tools.quality_scores import create_scorecard + + with ( + _patch_perms(), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError, match="Invalid category"), + ): + create_scorecard( + "demo", + "My Card", + filters=_VALID_FILTER, + category="Not A Category", + ) + mock_orch.assert_not_called() + + +def test_create_scorecard_persists_with_defaults(db_session_mock): + from testgen.mcp.tools.quality_scores import create_scorecard + + with _patch_perms(), _patch_orchestrator() as mock_orch: + create_scorecard("demo", "My Card", filters=_VALID_FILTER) + + assert mock_orch.call_count == 1 + saved = mock_orch.call_args.args[0] + assert isinstance(saved, ScoreDefinition) + assert saved.project_code == "demo" + assert saved.name == "My Card" + assert saved.total_score is True + assert saved.cde_score is False + assert saved.category is None + assert saved.criteria.group_by_field is True + assert saved.criteria.filters[0].field == "table_groups_name" + assert saved.criteria.filters[0].value == "tg1" + assert mock_orch.call_args.kwargs == {"is_new": True} + + +def test_create_scorecard_persists_with_overrides(db_session_mock): + from testgen.mcp.tools.quality_scores import create_scorecard + + with _patch_perms(), _patch_orchestrator() as mock_orch: + create_scorecard( + "demo", + "My Card", + filters=_VALID_FILTER, + category="Quality Dimension", + show_total_score=False, + show_cde_score=True, + ) + + saved = mock_orch.call_args.args[0] + assert saved.total_score is False + assert saved.cde_score is True + assert saved.category == ScoreCategory.dq_dimension + + +def test_create_scorecard_persists_mode_2_chained_filters(db_session_mock): + from testgen.mcp.tools.quality_scores import create_scorecard + + chained = [{ + "field": "Table Group", + "value": "tg1", + "others": [ + {"field": "Table", "value": "accounts"}, + {"field": "Column", "value": "id"}, + ], + }] + with _patch_perms(), _patch_orchestrator() as mock_orch: + create_scorecard("demo", "My Card", filters=chained) + + saved = mock_orch.call_args.args[0] + assert saved.criteria.group_by_field is False + root = saved.criteria.filters[0] + assert root.field == "table_groups_name" + assert root.value == "tg1" + assert root.next_filter is not None + assert root.next_filter.field == "table_name" + assert root.next_filter.next_filter.field == "column_name" + + +def test_create_scorecard_returns_markdown_summary(db_session_mock): + from testgen.mcp.tools.quality_scores import create_scorecard + + new_id = uuid4() + + def _set_id(definition, *, is_new): + definition.id = new_id + return definition + + with _patch_perms(), _patch_orchestrator() as mock_orch: + mock_orch.side_effect = _set_id + out = create_scorecard( + "demo", + "Finance Card", + filters=_VALID_FILTER, + category="Quality Dimension", + ) + + assert "Finance Card" in out + assert "demo" in out + assert str(new_id) in out + # Display label uses "Category", not "Breakdown By". + assert "Category" in out + assert "Breakdown By" not in out + # Friendly category label, not internal column name. + assert "Quality Dimension" in out + assert "dq_dimension" not in out + # Filter summary appears. + assert "Filters" in out + + +# ============================================================ +# Exhaustive corner-case coverage for the unified _validate_filters +# Each numbered test maps 1-to-1 to a case in the plan's Task 2 enumeration. +# Calls into _validate_filters directly (no MCP wrapper) except +# where the contract requires going through the tool itself. +# ============================================================ + +from testgen.mcp.tools.common import ( + SCORE_FILTER_FIELD_TO_COLUMN, + ScoreFilterField, +) +from testgen.mcp.tools.quality_scores import _validate_filters + +# --- A. Shape / required-field rejections --- + + +def test_validate_filters_case_01_empty_list_rejected(): + # case 1 + with pytest.raises(MCPUserError, match=r"At least one filter is required\."): + _validate_filters([]) + + +def test_validate_filters_case_02_missing_field_key_rejected(): + # case 2 + with pytest.raises(MCPUserError, match=r"filters\[0\].*field.*value"): + _validate_filters([{"value": "tg1"}]) + + +def test_validate_filters_case_03_missing_value_key_rejected(): + # case 3 + with pytest.raises(MCPUserError, match=r"filters\[0\].*field.*value"): + _validate_filters([{"field": "Table Group"}]) + + +def test_validate_filters_case_04_empty_string_field_rejected(): + # case 4 + with pytest.raises(MCPUserError, match=r"filters\[0\].*field.*value"): + _validate_filters([{"field": "", "value": "tg1"}]) + + +def test_validate_filters_case_05_empty_string_value_rejected(): + # case 5 + with pytest.raises(MCPUserError, match=r"filters\[0\].*field.*value"): + _validate_filters([{"field": "Table Group", "value": ""}]) + + +def test_validate_filters_case_06_none_field_rejected(): + # case 6 + with pytest.raises(MCPUserError, match=r"filters\[0\].*field.*value"): + _validate_filters([{"field": None, "value": "tg1"}]) + + +def test_validate_filters_case_07_none_value_rejected(): + # case 7 + with pytest.raises(MCPUserError, match=r"filters\[0\].*field.*value"): + _validate_filters([{"field": "Table Group", "value": None}]) + + +def test_validate_filters_case_08_second_filter_malformed_indexed_at_1(): + # case 8 — index propagation through enumerate + with pytest.raises(MCPUserError, match=r"filters\[1\]"): + _validate_filters([ + {"field": "Table Group", "value": "tg1"}, + {"field": "Table Group"}, + ]) + + +# --- B. SQL-injection value guard (flat path) --- + + +def test_validate_filters_case_09_value_with_single_quote_rejected(): + # case 9 + with pytest.raises(MCPUserError, match="forbidden"): + _validate_filters([{"field": "Table Group", "value": "tg1' OR '1'='1"}]) + + +def test_validate_filters_case_10_value_with_double_quote_rejected(): + # case 10 + with pytest.raises(MCPUserError, match="forbidden"): + _validate_filters([{"field": "Table Group", "value": 'tg1"'}]) + + +def test_validate_filters_case_11_value_with_semicolon_rejected(): + # case 11 + with pytest.raises(MCPUserError, match="forbidden"): + _validate_filters([{"field": "Table Group", "value": "tg1; DROP"}]) + + +def test_validate_filters_case_12_value_with_backslash_rejected(): + # case 12 + with pytest.raises(MCPUserError, match="forbidden"): + _validate_filters([{"field": "Table Group", "value": "tg1\\foo"}]) + + +def test_validate_filters_case_13_value_with_null_byte_rejected(): + # case 13 + with pytest.raises(MCPUserError, match="forbidden"): + _validate_filters([{"field": "Table Group", "value": "tg1\x00"}]) + + +def test_validate_filters_case_14_value_length_257_rejected(): + # case 14 — boundary: 257 > 256 limit + with pytest.raises(MCPUserError, match="too long"): + _validate_filters([{"field": "Table Group", "value": "x" * 257}]) + + +def test_validate_filters_case_15_value_length_256_accepted(): + # case 15 — boundary: 256 == limit, accepted + parsed, group_by_field = _validate_filters( + [{"field": "Table Group", "value": "x" * 256}] + ) + assert group_by_field is True + assert parsed[0]["field"] == "table_groups_name" + assert parsed[0]["value"] == "x" * 256 + + +@pytest.mark.parametrize( + "bad_value", + [123, [1, 2], {"k": "v"}, True], + ids=["case_16_int", "case_16_list", "case_16_dict", "case_16_bool"], +) +def test_validate_filters_case_16_value_non_string_rejected(bad_value): + # case 16 + with pytest.raises(MCPUserError, match="must be a string"): + _validate_filters([{"field": "Table Group", "value": bad_value}]) + + +# --- C. Mode 1 (flat, no others) — happy paths --- + + +def test_validate_filters_case_17_single_table_group_flat(): + # case 17 + parsed, group_by_field = _validate_filters( + [{"field": "Table Group", "value": "tg1"}] + ) + assert group_by_field is True + assert parsed == [{"field": "table_groups_name", "value": "tg1"}] + + +def test_validate_filters_case_18_two_filters_same_field(): + # case 18 — same display field, two values + parsed, group_by_field = _validate_filters([ + {"field": "Table Group", "value": "tg1"}, + {"field": "Table Group", "value": "tg2"}, + ]) + assert group_by_field is True + assert parsed == [ + {"field": "table_groups_name", "value": "tg1"}, + {"field": "table_groups_name", "value": "tg2"}, + ] + + +def test_validate_filters_case_19_two_filters_different_fields(): + # case 19 — different display fields + parsed, group_by_field = _validate_filters([ + {"field": "Table Group", "value": "tg1"}, + {"field": "Data Source", "value": "Postgres"}, + ]) + assert group_by_field is True + assert parsed == [ + {"field": "table_groups_name", "value": "tg1"}, + {"field": "data_source", "value": "Postgres"}, + ] + + +@pytest.mark.parametrize( + "field_enum", + list(ScoreFilterField), + ids=[f"case_20_{f.name}" for f in ScoreFilterField], +) +def test_validate_filters_case_20_every_score_filter_field_accepted(field_enum): + # case 20 — parametrize over every ScoreFilterField; assert translation + parsed, group_by_field = _validate_filters( + [{"field": field_enum.value, "value": "val"}] + ) + assert group_by_field is True + assert parsed[0]["field"] == SCORE_FILTER_FIELD_TO_COLUMN[field_enum] + assert parsed[0]["value"] == "val" + + +# --- D. Mode 1 rejection paths --- + + +def test_validate_filters_case_21_column_form_field_rejected(): + # case 21 — column-form `data_source` must be rejected; error lists display values + with pytest.raises(MCPUserError) as exc_info: + _validate_filters([{"field": "data_source", "value": "Postgres"}]) + msg = exc_info.value.args[0] + assert "Data Source" in msg + # Column-form must NOT appear as a "valid" suggestion + assert "`data_source`" in msg # the rejected value is quoted back + + +def test_validate_filters_case_22_lowercase_quality_dimension_rejected(): + # case 22 — case-sensitive enum lookup + with pytest.raises(MCPUserError, match="quality dimension"): + _validate_filters([{"field": "quality dimension", "value": "Validity"}]) + + +def test_validate_filters_case_23_quality_dimension_rejected_as_filter_field(): + # case 23 — valid group_by, not a valid filter field + with pytest.raises(MCPUserError) as exc_info: + _validate_filters([{"field": "Quality Dimension", "value": "Validity"}]) + assert "Quality Dimension" in exc_info.value.args[0] + + +def test_validate_filters_case_24_impact_dimension_rejected_as_filter_field(): + # case 24 + with pytest.raises(MCPUserError) as exc_info: + _validate_filters([{"field": "Impact Dimension", "value": "High"}]) + assert "Impact Dimension" in exc_info.value.args[0] + + +def test_validate_filters_case_25_invalid_field_xyz_rejected(): + # case 25 — totally bogus field + with pytest.raises(MCPUserError) as exc_info: + _validate_filters([{"field": "xyz", "value": "v"}]) + msg = exc_info.value.args[0] + assert "xyz" in msg + # Error should list display-form values + assert "Table Group" in msg + + +def test_validate_filters_case_26_empty_others_list_still_mode_1(): + # case 26 — others=[] is falsy in any(...) + parsed, group_by_field = _validate_filters( + [{"field": "Table Group", "value": "tg1", "others": []}] + ) + assert group_by_field is True + assert parsed[0]["field"] == "table_groups_name" + + +def test_validate_filters_case_27_none_others_still_mode_1(): + # case 27 — others=None is falsy in any(...) + parsed, group_by_field = _validate_filters( + [{"field": "Table Group", "value": "tg1", "others": None}] + ) + assert group_by_field is True + assert parsed[0]["field"] == "table_groups_name" + + +# --- E. Mode 2 (chained) — happy paths --- + + +def test_validate_filters_case_28_single_chain_one_step_table(): + # case 28 — Table Group → Table + parsed, group_by_field = _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Table", "value": "orders"}], + }]) + assert group_by_field is False + assert parsed == [{ + "field": "table_groups_name", + "value": "tg1", + "others": [{"field": "table_name", "value": "orders"}], + }] + + +def test_validate_filters_case_29_single_chain_two_steps_table_column(): + # case 29 — Table Group → Table → Column + parsed, group_by_field = _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [ + {"field": "Table", "value": "orders"}, + {"field": "Column", "value": "id"}, + ], + }]) + assert group_by_field is False + assert parsed[0]["others"] == [ + {"field": "table_name", "value": "orders"}, + {"field": "column_name", "value": "id"}, + ] + + +def test_validate_filters_case_30_mode_2_with_sibling_flat_table_group(): + # case 30 — chain-having filter + bare Table Group (entire-tg case) + parsed, group_by_field = _validate_filters([ + { + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Table", "value": "orders"}], + }, + {"field": "Table Group", "value": "tg2"}, + ]) + assert group_by_field is False + assert len(parsed) == 2 + assert parsed[1]["field"] == "table_groups_name" + assert parsed[1]["value"] == "tg2" + + +def test_validate_filters_case_31_multiple_chained_filters_same_shape(): + # case 31 — sibling OR semantics, both translated + parsed, group_by_field = _validate_filters([ + { + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Table", "value": "orders"}], + }, + { + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Table", "value": "customers"}], + }, + ]) + assert group_by_field is False + assert len(parsed) == 2 + for filter_ in parsed: + assert filter_["field"] == "table_groups_name" + assert filter_["others"][0]["field"] == "table_name" + + +def test_validate_filters_case_32_chain_leaf_value_length_256_accepted(): + # case 32 — boundary at chain leaf + parsed, group_by_field = _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Table", "value": "x" * 256}], + }]) + assert group_by_field is False + assert parsed[0]["others"][0]["value"] == "x" * 256 + + +# --- F. Mode 2 rejection paths --- + + +def test_validate_filters_case_33_root_not_table_group_with_others_rejected(): + # case 33 — has others but root is Data Source + with pytest.raises(MCPUserError) as exc_info: + _validate_filters([{ + "field": "Data Source", + "value": "S", + "others": [{"field": "Table", "value": "x"}], + }]) + assert "Table Group" in exc_info.value.args[0] + + +def test_validate_filters_case_34_sibling_with_data_source_root_in_chain_mode_rejected(): + # case 34 — one filter chains; sibling has Data Source root (no chain) → reject + with pytest.raises(MCPUserError) as exc_info: + _validate_filters([ + { + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Table", "value": "orders"}], + }, + {"field": "Data Source", "value": "Postgres"}, + ]) + assert "Table Group" in exc_info.value.args[0] + + +def test_validate_filters_case_35_column_without_preceding_table_rejected(): + # case 35 + with pytest.raises(MCPUserError, match="`Column` chain requires a `Table` step"): + _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Column", "value": "id"}], + }]) + + +def test_validate_filters_case_36_chain_order_column_then_table_rejected(): + # case 36 + with pytest.raises(MCPUserError, match="`Column` must be the final chain step"): + _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [ + {"field": "Column", "value": "id"}, + {"field": "Table", "value": "orders"}, + ], + }]) + + +def test_validate_filters_case_37_chain_leaf_column_form_table_name_rejected(): + # case 37 — column-form leaf `table_name` rejected (display-form only) + with pytest.raises(MCPUserError) as exc_info: + _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [{"field": "table_name", "value": "orders"}], + }]) + msg = exc_info.value.args[0] + assert "table_name" in msg # the rejected field is quoted in the error + # Valid leaves listed in display form + assert "Table" in msg + assert "Column" in msg + + +def test_validate_filters_case_38_invalid_chain_leaf_field_rejected(): + # case 38 + with pytest.raises(MCPUserError) as exc_info: + _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [{"field": "something_else", "value": "v"}], + }]) + msg = exc_info.value.args[0] + assert "something_else" in msg + assert "Table" in msg + assert "Column" in msg + + +def test_validate_filters_case_39_chain_leaf_missing_field_rejected(): + # case 39 — indexed filters[0].others[0] + with pytest.raises(MCPUserError, match=r"filters\[0\]\.others\[0\].*field.*value"): + _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [{"value": "orders"}], + }]) + + +def test_validate_filters_case_40_chain_leaf_missing_value_rejected(): + # case 40 — indexed filters[0].others[0] + with pytest.raises(MCPUserError, match=r"filters\[0\]\.others\[0\].*field.*value"): + _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Table"}], + }]) + + +def test_validate_filters_case_41_chain_leaf_value_with_forbidden_char_rejected(): + # case 41 — indexed + with pytest.raises(MCPUserError, match=r"filters\[0\]\.others\[0\].*forbidden"): + _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Table", "value": "x'; DROP"}], + }]) + + +def test_validate_filters_case_42_chain_leaf_value_too_long_rejected(): + # case 42 — indexed + with pytest.raises(MCPUserError, match=r"filters\[0\]\.others\[0\].*too long"): + _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Table", "value": "x" * 300}], + }]) + + +def test_validate_filters_case_43_chain_leaf_value_non_string_rejected(): + # case 43 — indexed + with pytest.raises(MCPUserError, match=r"filters\[0\]\.others\[0\].*must be a string"): + _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Table", "value": 123}], + }]) + + +def test_validate_filters_case_44_chain_with_extra_trailing_column_rejected(): + # case 44 — Table → Column → Column: second Column is in prefix, not the end + with pytest.raises(MCPUserError, match="`Column` must be the final chain step"): + _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [ + {"field": "Table", "value": "orders"}, + {"field": "Column", "value": "id"}, + {"field": "Column", "value": "name"}, + ], + }]) + + +# --- G. Translation correctness (output-shape) --- + + +def test_validate_filters_case_45_output_has_only_column_form_keys(): + # case 45 — every returned `field` is column-form + parsed, _ = _validate_filters([ + {"field": "Table Group", "value": "tg1"}, + {"field": "Data Source", "value": "Postgres"}, + {"field": "Quality Dimension", "value": "ignored"}, # would fail; remove + ][:2]) # only the first two — Quality Dimension isn't a valid filter + column_form_field_values = set(SCORE_FILTER_FIELD_TO_COLUMN.values()) | { + "table_name", "column_name", + } + for filter_ in parsed: + assert filter_["field"] in column_form_field_values + + # Also test chain leaves + parsed_chain, _ = _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [ + {"field": "Table", "value": "orders"}, + {"field": "Column", "value": "id"}, + ], + }]) + for filter_ in parsed_chain: + assert filter_["field"] in column_form_field_values + for leaf in filter_.get("others", []): + assert leaf["field"] in column_form_field_values + + +def test_validate_filters_case_46_values_byte_identical_to_input(): + # case 46 — values are NEVER mutated by translation + raw = [{ + "field": "Table Group", + "value": "MixedCaseTG.with-dots_underscores", + "others": [ + {"field": "Table", "value": "Orders Table"}, + {"field": "Column", "value": "ID-Col"}, + ], + }] + parsed, _ = _validate_filters(raw) + assert parsed[0]["value"] == "MixedCaseTG.with-dots_underscores" + assert parsed[0]["others"][0]["value"] == "Orders Table" + assert parsed[0]["others"][1]["value"] == "ID-Col" + + +def test_validate_filters_case_47_group_by_field_flag_correct(): + # case 47 — True iff no filter has non-empty others + _, flag_flat = _validate_filters([{"field": "Table Group", "value": "tg1"}]) + assert flag_flat is True + _, flag_chained = _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Table", "value": "orders"}], + }]) + assert flag_chained is False + + +# --- H. Error-message hygiene (regression guards) --- + + +def test_validate_filters_case_48_flat_error_message_uses_display_form(): + # case 48 — error mentions valid flat fields: at least one display-form value; + # no underscore-form column names + with pytest.raises(MCPUserError) as exc_info: + _validate_filters([{"field": "xyz", "value": "v"}]) + msg = exc_info.value.args[0] + assert "Table Group" in msg # display-form present + # None of the column-form values should appear as a "valid" suggestion + # (the rejected `xyz` is fine; we're checking valid-values listing) + column_form_values = set(SCORE_FILTER_FIELD_TO_COLUMN.values()) + for col_value in column_form_values: + # Each column-form value (e.g. "table_groups_name", "data_source") must + # not appear in the listed-valid set. The rejected field name is also + # mentioned, but that's a user-supplied string, not "xyz" matching. + assert col_value not in msg, ( + f"Error message must not list column-form `{col_value}` as a valid value. " + f"Full message: {msg}" + ) + + +def test_validate_filters_case_49_chain_leaf_error_uses_display_form(): + # case 49 — leaf error mentions Table and Column, not table_name/column_name + with pytest.raises(MCPUserError) as exc_info: + _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [{"field": "something_else", "value": "v"}], + }]) + msg = exc_info.value.args[0] + assert "Table" in msg + assert "Column" in msg + # The column-form leaf names must not appear as "valid" leaves + assert "table_name" not in msg.replace("`Table`", "").replace("Table", "") # `Table` ok, `table_name` not + assert "column_name" not in msg.replace("`Column`", "").replace("Column", "") + + +# --- Wrapper-level: column-form rejected through create_scorecard --- + + +def test_create_scorecard_rejects_column_form_field_through_wrapper(db_session_mock): + # case 21 mirrored at the MCP wrapper level + from testgen.mcp.tools.quality_scores import create_scorecard + + with ( + _patch_perms(), + _patch_orchestrator() as mock_orch, + pytest.raises(MCPUserError) as exc_info, + ): + create_scorecard( + "demo", + "My Card", + filters=[{"field": "data_source", "value": "Postgres"}], + ) + msg = exc_info.value.args[0] + assert "Data Source" in msg + mock_orch.assert_not_called() + + +# ============================================================ +# Unified validator: allow_empty + multi-error collection +# ============================================================ + + +def test_validate_filters_empty_default_rejected(): + """With the default allow_empty=False, an empty list raises.""" + with pytest.raises(MCPUserError, match=r"At least one filter is required\."): + _validate_filters([]) + + +def test_validate_filters_none_default_rejected(): + """With the default allow_empty=False, None raises.""" + with pytest.raises(MCPUserError, match=r"At least one filter is required\."): + _validate_filters(None) + + +def test_validate_filters_empty_allowed_returns_empty_tuple(): + """allow_empty=True short-circuits an empty list to ([], True).""" + parsed, group_by_field = _validate_filters([], allow_empty=True) + assert parsed == [] + assert group_by_field is True + + +def test_validate_filters_none_allowed_returns_empty_tuple(): + """allow_empty=True short-circuits None to ([], True).""" + parsed, group_by_field = _validate_filters(None, allow_empty=True) + assert parsed == [] + assert group_by_field is True + + +def test_validate_filters_collects_multiple_flat_errors(): + """Multi-error collection: every offending entry is named in the message.""" + with pytest.raises(MCPUserError) as exc_info: + _validate_filters([ + {"field": "Quality Dimension", "value": "Accuracy"}, # not a filter field + {"field": "Business Domain", "value": "x';--"}, # forbidden chars + {"field": "Data Source", "value": ""}, # empty value + ]) + msg = exc_info.value.args[0] + assert "Quality Dimension" in msg + assert "Business Domain" in msg + assert "Data Source" in msg + + +def test_validate_filters_collects_multiple_chain_leaf_errors(): + """Chain-mode also collects per-leaf errors instead of stopping at the first.""" + with pytest.raises(MCPUserError) as exc_info: + _validate_filters([{ + "field": "Table Group", + "value": "tg1", + "others": [ + {"field": "bogus_leaf", "value": "x"}, # invalid leaf field + {"field": "Table", "value": "tbl';DROP"}, # forbidden char in valid leaf + ], + }]) + msg = exc_info.value.args[0] + assert "bogus_leaf" in msg + assert "forbidden" in msg + + +# ============================================================ +# get_quality_scores: mode-2 chained filter support +# ============================================================ + + +@patch("testgen.mcp.tools.quality_scores.ScoreDefinitionCriteria") +@patch("testgen.mcp.tools.quality_scores.ScoreDefinition") +def test_get_quality_scores_accepts_mode_2_chained_filters( + mock_definition_cls, mock_criteria_cls, db_session_mock, +): + """A chained Table Group → Table filter reaches from_filters with group_by_field=False.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + mock_definition = MagicMock() + mock_definition.as_score_card.return_value = _score_card(score=0.9) + mock_definition_cls.return_value = mock_definition + + with _patch_perms(): + get_quality_scores( + project_code="demo", + filters=[{ + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Table", "value": "orders"}], + }], + ) + + mock_criteria_cls.from_filters.assert_called_once() + args, kwargs = mock_criteria_cls.from_filters.call_args + passed = args[0] + assert kwargs.get("group_by_field") is False + assert passed[0]["field"] == "table_groups_name" + assert passed[0]["value"] == "tg1" + assert passed[0]["others"] == [{"field": "table_name", "value": "orders"}] + + +def test_get_quality_scores_rejects_table_group_id_with_chained_filters(db_session_mock): + """table_group_id + a mode-2 chain conflict — the implicit name filter would + shadow the chain root, so reject explicitly.""" + from testgen.mcp.tools.quality_scores import get_quality_scores + + tg = MagicMock() + tg.id = uuid4() + tg.project_code = "demo" + tg.table_groups_name = "orders_tg" + + with ( + _patch_perms(), + patch("testgen.mcp.tools.common.TableGroup") as mock_tg_cls, + pytest.raises(MCPUserError, match="chained filters"), + ): + mock_tg_cls.get.return_value = tg + get_quality_scores( + table_group_id=str(tg.id), + filters=[{ + "field": "Table Group", + "value": "tg1", + "others": [{"field": "Table", "value": "orders"}], + }], + ) From b104cb14710ec48145b43e440feb7bce5d0d5aea Mon Sep 17 00:00:00 2001 From: Astor Date: Fri, 22 May 2026 16:48:51 -0300 Subject: [PATCH 36/58] refactor(TG-1041): address second round of reviewer feedback - Rename select_page params page_index/page_size to page/limit (1-based, consistent with list_for_suite); do 0-to-1 conversion at call site in view - Remove dead get_test_definitions_count function (superseded by get_test_definitions returning total_count) - Update get_test_definitions docstring to reference select_page and fix description - Move select_page tests into test_test_definition.py and delete separate pagination test file - Drop test_select_page_item_has_no_total_count_field (tested absence of a bug no code path could create) Co-Authored-By: Claude Sonnet 4.6 --- testgen/common/models/test_definition.py | 8 +- testgen/ui/views/test_definitions.py | 41 +---- .../common/models/test_test_definition.py | 141 +++++++++++++++- .../models/test_test_definition_pagination.py | 155 ------------------ 4 files changed, 147 insertions(+), 198 deletions(-) delete mode 100644 tests/unit/common/models/test_test_definition_pagination.py diff --git a/testgen/common/models/test_definition.py b/testgen/common/models/test_definition.py index c7f9ca17..18a3691e 100644 --- a/testgen/common/models/test_definition.py +++ b/testgen/common/models/test_definition.py @@ -442,8 +442,8 @@ def select_page( cls, *clauses, order_by: tuple[str | InstrumentedAttribute] | None = None, - page_index: int = 0, - page_size: int = 500, + page: int = 1, + limit: int = 500, ) -> tuple[list["TestDefinitionSummary"], int]: select_columns = [ getattr(cls, col, None) or getattr(TestType, col) if isinstance(col, str) else col @@ -455,9 +455,7 @@ def select_page( .where(*clauses) .order_by(*(order_by or cls._default_order_by)) ) - return cls._paginate(query, page=page_index + 1, limit=page_size, data_class=TestDefinitionSummary) - - + return cls._paginate(query, page=page, limit=limit, data_class=TestDefinitionSummary) _yn_columns: ClassVar = {"test_active", "lock_refresh"} diff --git a/testgen/ui/views/test_definitions.py b/testgen/ui/views/test_definitions.py index f01f888e..05f3390a 100644 --- a/testgen/ui/views/test_definitions.py +++ b/testgen/ui/views/test_definitions.py @@ -816,12 +816,11 @@ def get_test_definitions( page_size: int = 500, flagged_filter: str | None = None, ) -> tuple[pd.DataFrame, int]: - """Return ``(df, total_count)`` for one page of test definitions. + """Return ``(df, total_count)`` for test definitions matching the given filters. When ``page_index`` is provided (0-based), fetches only that page from - the DB using ``TestDefinition._paginate()``; otherwise fetches all rows - via ``select_where()``. ``total_count`` is the full count of matching - rows regardless of which page is requested. + the DB using ``TestDefinition.select_page()``; otherwise fetches all rows + via ``select_where()``. ``total_count`` is always the full matching count. """ clauses = [TestDefinition.test_suite_id == test_suite.id] if table_name: @@ -856,8 +855,8 @@ def get_test_definitions( test_definitions, total_count = TestDefinition.select_page( *clauses, order_by=order_by_tuple, - page_index=page_index, - page_size=page_size, + page=page_index + 1, + limit=page_size, ) else: test_definitions = TestDefinition.select_where(*clauses, order_by=order_by_tuple) @@ -896,36 +895,6 @@ def get_export_to_observability_display(value: str) -> str: return df, total_count -def get_test_definitions_count( - test_suite: TestSuite, - table_name: str | None = None, - column_name: str | None = None, - test_type: str | None = None, - flagged_filter: str | None = None, -) -> int: - from testgen.ui.services.database_service import fetch_one_from_db - - where_parts = ["test_suite_id = :test_suite_id"] - params: dict = {"test_suite_id": str(test_suite.id)} - if table_name: - where_parts.append("table_name = :table_name") - params["table_name"] = table_name - if column_name: - where_parts.append("column_name ILIKE :column_name") - params["column_name"] = column_name - if test_type: - where_parts.append("test_type = :test_type") - params["test_type"] = test_type - if flagged_filter == "Flagged": - where_parts.append("flagged = true") - elif flagged_filter == "Not Flagged": - where_parts.append("flagged = false") - - query = f"SELECT COUNT(*) as cnt FROM test_definitions WHERE {' AND '.join(where_parts)};" - result = fetch_one_from_db(query, params) - return int(result["cnt"]) if result else 0 - - def get_test_definition_ids( test_suite: TestSuite, table_name: str | None = None, diff --git a/tests/unit/common/models/test_test_definition.py b/tests/unit/common/models/test_test_definition.py index 8d9ffb12..b7cd9431 100644 --- a/tests/unit/common/models/test_test_definition.py +++ b/tests/unit/common/models/test_test_definition.py @@ -1,6 +1,8 @@ -"""Tests for TestDefinition.validate() and TestDefinition.editable_fields().""" +"""Tests for TestDefinition model methods.""" -from unittest.mock import MagicMock +from datetime import datetime +from unittest.mock import MagicMock, patch +from uuid import uuid4 import pytest @@ -8,6 +10,7 @@ InvalidTestDefinitionFields, Severity, TestDefinition, + TestDefinitionSummary, _required_fields_for, ) @@ -239,3 +242,137 @@ def test_severity_enum_value_accepted(): tt = make_test_type() td = make_td(column_name="email", threshold_value="10", severity=Severity.FAIL) td.validate(tt) + + +# --- select_page --- + +@pytest.fixture(autouse=True) +def clear_select_page_cache(): + TestDefinition.select_page.clear() + yield + + +def _make_summary_row(table_name: str = "my_table") -> dict: + return { + "id": uuid4(), + "table_groups_id": uuid4(), + "profile_run_id": uuid4(), + "test_type": "CUSTOM", + "test_suite_id": uuid4(), + "test_description": None, + "schema_name": "public", + "table_name": table_name, + "column_name": "col1", + "skip_errors": 0, + "baseline_ct": None, + "baseline_unique_ct": None, + "baseline_value": None, + "baseline_value_ct": None, + "threshold_value": None, + "baseline_sum": None, + "baseline_avg": None, + "baseline_sd": None, + "lower_tolerance": None, + "upper_tolerance": None, + "subset_condition": None, + "groupby_names": None, + "having_condition": None, + "window_date_column": None, + "window_days": None, + "match_schema_name": None, + "match_table_name": None, + "match_column_names": None, + "match_subset_condition": None, + "match_groupby_names": None, + "match_having_condition": None, + "custom_query": None, + "history_calculation": None, + "history_calculation_upper": None, + "history_lookback": None, + "test_active": True, + "test_definition_status": None, + "severity": None, + "lock_refresh": False, + "last_auto_gen_date": None, + "profiling_as_of_date": None, + "last_manual_update": datetime.now(), + "export_to_observability": False, + "prediction": None, + "flagged": False, + "impact_dimension": None, + "test_name_short": "Custom", + "default_test_description": "A test", + "measure_uom": "", + "measure_uom_description": "", + "default_parm_columns": "", + "default_parm_prompts": "", + "default_parm_help": "", + "default_parm_required": "", + "default_severity": "Warning", + "test_scope": "column", + "dq_dimension": "", + "default_impact_dimension": "", + "usage_notes": "", + } + + +@patch("testgen.common.models.entity.get_current_session") +def test_select_page_returns_items_and_total(mock_get_session): + rows = [_make_summary_row("table_a"), _make_summary_row("table_b"), _make_summary_row("table_c")] + mock_session = mock_get_session.return_value + mock_session.scalar.return_value = 3 + mock_session.execute.return_value.mappings.return_value.all.return_value = rows + + items, total = TestDefinition.select_page() + + assert total == 3 + assert len(items) == 3 + assert all(isinstance(item, TestDefinitionSummary) for item in items) + assert items[0].table_name == "table_a" + assert items[2].table_name == "table_c" + + +@patch("testgen.common.models.entity.get_current_session") +def test_select_page_empty_result_returns_zero_total(mock_get_session): + mock_session = mock_get_session.return_value + mock_session.scalar.return_value = 0 + mock_session.execute.return_value.mappings.return_value.all.return_value = [] + + items, total = TestDefinition.select_page() + + assert items == [] + assert total == 0 + + +@patch("testgen.common.models.entity.get_current_session") +def test_select_page_uses_correct_offset_and_limit(mock_get_session): + mock_session = mock_get_session.return_value + mock_session.scalar.return_value = 0 + mock_session.execute.return_value.mappings.return_value.all.return_value = [] + + TestDefinition.select_page(page=3, limit=100) + + call_args = mock_session.execute.call_args + query = call_args[0][0] + compiled = query.compile(compile_kwargs={"literal_binds": True}) + sql = str(compiled) + + assert "LIMIT 100" in sql + assert "OFFSET 200" in sql + + +@patch("testgen.common.models.entity.get_current_session") +def test_select_page_first_page_has_no_offset(mock_get_session): + mock_session = mock_get_session.return_value + mock_session.scalar.return_value = 0 + mock_session.execute.return_value.mappings.return_value.all.return_value = [] + + TestDefinition.select_page(page=1, limit=500) + + call_args = mock_session.execute.call_args + query = call_args[0][0] + compiled = query.compile(compile_kwargs={"literal_binds": True}) + sql = str(compiled) + + assert "LIMIT 500" in sql + assert "OFFSET 0" in sql diff --git a/tests/unit/common/models/test_test_definition_pagination.py b/tests/unit/common/models/test_test_definition_pagination.py deleted file mode 100644 index fa844f90..00000000 --- a/tests/unit/common/models/test_test_definition_pagination.py +++ /dev/null @@ -1,155 +0,0 @@ -from datetime import datetime -from unittest.mock import patch -from uuid import uuid4 - -import pytest - -from testgen.common.models.test_definition import TestDefinition, TestDefinitionSummary - -pytestmark = pytest.mark.unit - - -@pytest.fixture(autouse=True) -def clear_streamlit_cache(): - TestDefinition.select_page.clear() - yield - - -def _make_row(table_name: str = "my_table") -> dict: - """Return a minimal row dict as returned by session.execute().mappings().all().""" - return { - # TestDefinitionSummary fields - "id": uuid4(), - "table_groups_id": uuid4(), - "profile_run_id": uuid4(), - "test_type": "CUSTOM", - "test_suite_id": uuid4(), - "test_description": None, - "schema_name": "public", - "table_name": table_name, - "column_name": "col1", - "skip_errors": 0, - "baseline_ct": None, - "baseline_unique_ct": None, - "baseline_value": None, - "baseline_value_ct": None, - "threshold_value": None, - "baseline_sum": None, - "baseline_avg": None, - "baseline_sd": None, - "lower_tolerance": None, - "upper_tolerance": None, - "subset_condition": None, - "groupby_names": None, - "having_condition": None, - "window_date_column": None, - "window_days": None, - "match_schema_name": None, - "match_table_name": None, - "match_column_names": None, - "match_subset_condition": None, - "match_groupby_names": None, - "match_having_condition": None, - "custom_query": None, - "history_calculation": None, - "history_calculation_upper": None, - "history_lookback": None, - "test_active": True, - "test_definition_status": None, - "severity": None, - "lock_refresh": False, - "last_auto_gen_date": None, - "profiling_as_of_date": None, - "last_manual_update": datetime.now(), - "export_to_observability": False, - "prediction": None, - "flagged": False, - "impact_dimension": None, - # TestTypeSummary fields - "test_name_short": "Custom", - "default_test_description": "A test", - "measure_uom": "", - "measure_uom_description": "", - "default_parm_columns": "", - "default_parm_prompts": "", - "default_parm_help": "", - "default_parm_required": "", - "default_severity": "Warning", - "test_scope": "column", - "dq_dimension": "", - "default_impact_dimension": "", - "usage_notes": "", - } - - -@patch("testgen.common.models.entity.get_current_session") -def test_select_page_returns_items_and_total(mock_get_session): - rows = [_make_row("table_a"), _make_row("table_b"), _make_row("table_c")] - mock_session = mock_get_session.return_value - mock_session.scalar.return_value = 3 - mock_session.execute.return_value.mappings.return_value.all.return_value = rows - - items, total = TestDefinition.select_page() - - assert total == 3 - assert len(items) == 3 - assert all(isinstance(item, TestDefinitionSummary) for item in items) - assert items[0].table_name == "table_a" - assert items[2].table_name == "table_c" - - -@patch("testgen.common.models.entity.get_current_session") -def test_select_page_empty_result_returns_zero_total(mock_get_session): - mock_session = mock_get_session.return_value - mock_session.scalar.return_value = 0 - mock_session.execute.return_value.mappings.return_value.all.return_value = [] - - items, total = TestDefinition.select_page() - - assert items == [] - assert total == 0 - - -@patch("testgen.common.models.entity.get_current_session") -def test_select_page_item_has_no_total_count_field(mock_get_session): - mock_session = mock_get_session.return_value - mock_session.scalar.return_value = 1 - mock_session.execute.return_value.mappings.return_value.all.return_value = [_make_row()] - - items, _ = TestDefinition.select_page() - - assert not hasattr(items[0], "total_count") - - -@patch("testgen.common.models.entity.get_current_session") -def test_select_page_uses_correct_offset_and_limit(mock_get_session): - mock_session = mock_get_session.return_value - mock_session.scalar.return_value = 0 - mock_session.execute.return_value.mappings.return_value.all.return_value = [] - - TestDefinition.select_page(page_index=2, page_size=100) - - call_args = mock_session.execute.call_args - query = call_args[0][0] - compiled = query.compile(compile_kwargs={"literal_binds": True}) - sql = str(compiled) - - assert "LIMIT 100" in sql - assert "OFFSET 200" in sql - - -@patch("testgen.common.models.entity.get_current_session") -def test_select_page_page_zero_has_no_offset(mock_get_session): - mock_session = mock_get_session.return_value - mock_session.scalar.return_value = 0 - mock_session.execute.return_value.mappings.return_value.all.return_value = [] - - TestDefinition.select_page(page_index=0, page_size=500) - - call_args = mock_session.execute.call_args - query = call_args[0][0] - compiled = query.compile(compile_kwargs={"literal_binds": True}) - sql = str(compiled) - - assert "LIMIT 500" in sql - assert "OFFSET 0" in sql From 9f0a452d0608fdc431684586c54f88e4b050f970 Mon Sep 17 00:00:00 2001 From: Luis Date: Tue, 26 May 2026 14:39:21 -0400 Subject: [PATCH 37/58] refactor(models): decouple Streamlit cache from common layer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit @st.cache_data decorators on model methods caused Streamlit's in-process cache to leak into MCP, API, scheduler, and CLI — at least one bug was traced to MCP tools returning stale data because the underlying model method was cached by a decorator that should not have been there. Move all caching out of the model layer and into UI wrappers: - Strip 16 @st.cache_data decorators across 8 files in common/models/ (entity, connection, user, table_group, test_suite, test_run, profiling_run, test_definition); remove the now-unused `import streamlit as st` lines. ENTITY_HASH_FUNCS stays in entity.py — the hash function is tied to SQLAlchemy types and is now imported by ui/services/query_cache.py. - Add ~25 per-(entity, method) cached wrappers in ui/services/query_cache.py covering both the formerly-decorated methods and inherited-from-Entity callers (Project, ProjectMembership). - Migrate ~55 UI call sites and 11 targeted .clear() invalidations to the wrappers; the model methods no longer have .clear() and the wrappers expose it. - Add tests/unit/common/test_no_streamlit_in_common.py as a boundary guard that fails if @st.cache_* or `import streamlit` reappears under common/. - Add tests/unit/ui/services/test_query_cache.py to verify each wrapper exists, is callable, and exposes .clear(). - Drop the now-obsolete clear_select_page_cache fixture from tests/unit/common/models/test_test_definition.py. Co-Authored-By: Claude Opus 4.7 (1M context) --- testgen/common/models/connection.py | 5 +- testgen/common/models/entity.py | 3 - testgen/common/models/profiling_run.py | 5 +- testgen/common/models/table_group.py | 5 +- testgen/common/models/test_definition.py | 7 +- testgen/common/models/test_run.py | 2 - testgen/common/models/test_suite.py | 5 +- testgen/common/models/user.py | 2 - testgen/ui/app.py | 4 +- testgen/ui/auth.py | 12 +- testgen/ui/queries/table_group_queries.py | 3 +- testgen/ui/services/query_cache.py | 185 +++++++++++++++++- testgen/ui/views/connections.py | 13 +- testgen/ui/views/data_catalog.py | 19 +- .../views/dialogs/import_metadata_dialog.py | 6 +- testgen/ui/views/dialogs/run_tests_dialog.py | 4 +- testgen/ui/views/hygiene_issues.py | 4 +- testgen/ui/views/monitors_dashboard.py | 27 ++- testgen/ui/views/profiling_results.py | 4 +- testgen/ui/views/profiling_runs.py | 16 +- testgen/ui/views/project_settings.py | 5 +- testgen/ui/views/table_groups.py | 27 ++- testgen/ui/views/test_definitions.py | 39 ++-- testgen/ui/views/test_results.py | 31 +-- testgen/ui/views/test_runs.py | 18 +- testgen/ui/views/test_suites.py | 23 ++- .../common/models/test_test_definition.py | 6 - .../common/test_no_streamlit_in_common.py | 47 +++++ tests/unit/ui/services/__init__.py | 0 tests/unit/ui/services/test_query_cache.py | 60 ++++++ tests/unit/ui/test_project_settings.py | 4 +- 31 files changed, 455 insertions(+), 136 deletions(-) create mode 100644 tests/unit/common/test_no_streamlit_in_common.py create mode 100644 tests/unit/ui/services/__init__.py create mode 100644 tests/unit/ui/services/test_query_cache.py diff --git a/testgen/common/models/connection.py b/testgen/common/models/connection.py index e209299b..a940edba 100644 --- a/testgen/common/models/connection.py +++ b/testgen/common/models/connection.py @@ -4,7 +4,6 @@ from urllib.parse import parse_qs, urlparse from uuid import UUID, uuid4 -import streamlit as st from sqlalchemy import ( BigInteger, Boolean, @@ -23,7 +22,7 @@ from testgen.common.database.flavor.flavor_service import SQLFlavor from testgen.common.models import get_current_session from testgen.common.models.custom_types import JSON_TYPE, EncryptedBytea, EncryptedJson -from testgen.common.models.entity import ENTITY_HASH_FUNCS, Entity, EntityMinimal +from testgen.common.models.entity import Entity, EntityMinimal from testgen.common.models.table_group import TableGroup from testgen.utils import is_uuid4 @@ -70,7 +69,6 @@ class Connection(Entity): _minimal_columns = ConnectionMinimal.__annotations__.keys() @classmethod - @st.cache_data(show_spinner=False) def get_minimal(cls, identifier: int) -> ConnectionMinimal | None: result = cls._get_columns(identifier, cls._minimal_columns) return ConnectionMinimal(**result) if result else None @@ -84,7 +82,6 @@ def get_by_table_group(cls, table_group_id: str | UUID) -> Self | None: return get_current_session().scalars(query).first() @classmethod - @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) def select_minimal_where( cls, *clauses, order_by: tuple[str | InstrumentedAttribute] = _default_order_by ) -> Iterable[ConnectionMinimal]: diff --git a/testgen/common/models/entity.py b/testgen/common/models/entity.py index 8f055bda..263c4f5a 100644 --- a/testgen/common/models/entity.py +++ b/testgen/common/models/entity.py @@ -3,7 +3,6 @@ from typing import Any, Self from uuid import UUID -import streamlit as st from sqlalchemy import delete, func, select from sqlalchemy.dialects import postgresql from sqlalchemy.orm import InstrumentedAttribute @@ -46,7 +45,6 @@ class Entity(Base): _default_order_by: tuple[str | InstrumentedAttribute] = ("id",) @classmethod - @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) def get(cls, identifier: str | int | UUID, *clauses) -> Self | None: """Fetch by primary key, optionally narrowed by extra WHERE clauses. @@ -89,7 +87,6 @@ def _get_columns( return get_current_session().execute(query).mappings().first() @classmethod - @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) def select_where(cls, *clauses, order_by: tuple[str | InstrumentedAttribute] | None = None) -> Iterable[Self]: order_by = order_by or cls._default_order_by query = select(cls).where(*clauses).order_by(*order_by) diff --git a/testgen/common/models/profiling_run.py b/testgen/common/models/profiling_run.py index 20e69a27..93a1d959 100644 --- a/testgen/common/models/profiling_run.py +++ b/testgen/common/models/profiling_run.py @@ -4,7 +4,6 @@ from typing import ClassVar, Literal, NamedTuple, Self, TypedDict from uuid import UUID, uuid4 -import streamlit as st from sqlalchemy import BigInteger, Column, Float, Integer, String, desc, func, select, text, update from sqlalchemy.dialects import postgresql from sqlalchemy.orm import InstrumentedAttribute @@ -14,7 +13,7 @@ from testgen.common.enums import Disposition, JobStatus from testgen.common.models import get_current_session from testgen.common.models.connection import Connection -from testgen.common.models.entity import ENTITY_HASH_FUNCS, Entity, EntityMinimal +from testgen.common.models.entity import Entity, EntityMinimal from testgen.common.models.job_execution import JobExecution from testgen.common.models.profile_result import ProfileResult from testgen.common.models.project import Project @@ -148,7 +147,6 @@ def get_by_id_or_job(cls, identifier: UUID) -> Self | None: return get_current_session().scalars(query).first() @classmethod - @st.cache_data(show_spinner=False) def get_minimal(cls, run_id: str | UUID) -> ProfilingRunMinimal | None: if not is_uuid4(run_id): return None @@ -193,7 +191,6 @@ def get_latest_complete_je_id_for_table_group(cls, table_groups_id: UUID) -> UUI return get_current_session().scalar(query) @classmethod - @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) def select_minimal_where( cls, *clauses, order_by: tuple[str | InstrumentedAttribute] = _default_order_by ) -> Iterable[ProfilingRunMinimal]: diff --git a/testgen/common/models/table_group.py b/testgen/common/models/table_group.py index 117e8983..251dcc46 100644 --- a/testgen/common/models/table_group.py +++ b/testgen/common/models/table_group.py @@ -3,14 +3,13 @@ from datetime import datetime from uuid import UUID, uuid4 -import streamlit as st from sqlalchemy import BigInteger, Boolean, Column, Float, ForeignKey, Integer, String, asc, func, text, update from sqlalchemy.dialects import postgresql from sqlalchemy.orm import InstrumentedAttribute from testgen.common.models import get_current_session from testgen.common.models.custom_types import NullIfEmptyString, YNString -from testgen.common.models.entity import ENTITY_HASH_FUNCS, Entity, EntityMinimal +from testgen.common.models.entity import Entity, EntityMinimal from testgen.common.models.scores import ScoreDefinition from testgen.common.models.test_suite import TestSuite from testgen.utils import is_uuid4 @@ -151,13 +150,11 @@ class TableGroup(Entity): ) @classmethod - @st.cache_data(show_spinner=False) def get_minimal(cls, id_: str | UUID) -> TableGroupMinimal | None: result = cls._get_columns(id_, cls._minimal_columns) return TableGroupMinimal(**result) if result else None @classmethod - @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) def select_minimal_where( cls, *clauses, order_by: tuple[str | InstrumentedAttribute] = _default_order_by ) -> Iterable[TableGroupMinimal]: diff --git a/testgen/common/models/test_definition.py b/testgen/common/models/test_definition.py index 18a3691e..82b7b53c 100644 --- a/testgen/common/models/test_definition.py +++ b/testgen/common/models/test_definition.py @@ -6,7 +6,6 @@ from typing import ClassVar, Literal from uuid import UUID, uuid4 -import streamlit as st from sqlalchemy import ( Boolean, Column, @@ -28,7 +27,7 @@ from testgen.common.models import Base, get_current_session from testgen.common.models.custom_types import NullIfEmptyString, YNString, ZeroIfEmptyInteger -from testgen.common.models.entity import ENTITY_HASH_FUNCS, Entity, EntityMinimal +from testgen.common.models.entity import Entity, EntityMinimal from testgen.utils import is_uuid4 TestRunType = Literal["QUERY", "CAT", "METADATA"] @@ -331,7 +330,6 @@ class TestDefinition(Entity): ) @classmethod - @st.cache_data(show_spinner=False) def get(cls, identifier: str | UUID) -> TestDefinitionSummary | None: if not is_uuid4(identifier): return None @@ -370,7 +368,6 @@ def get_for_project( return TestDefinitionSummary(**result) if result else None @classmethod - @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) def select_where( cls, *clauses, order_by: tuple[str | InstrumentedAttribute] = _default_order_by ) -> Iterable[TestDefinitionSummary]: @@ -384,7 +381,6 @@ def select_where( return [TestDefinitionSummary(**row) for row in results] @classmethod - @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) def select_minimal_where( cls, *clauses, order_by: tuple[str | InstrumentedAttribute] = _default_order_by ) -> Iterable[TestDefinitionMinimal]: @@ -437,7 +433,6 @@ def list_for_suite( return cls._paginate(query, page=page, limit=limit, data_class=TestDefinitionSummary) @classmethod - @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) def select_page( cls, *clauses, diff --git a/testgen/common/models/test_run.py b/testgen/common/models/test_run.py index 27e52a8b..44ede919 100644 --- a/testgen/common/models/test_run.py +++ b/testgen/common/models/test_run.py @@ -3,7 +3,6 @@ from typing import ClassVar, Literal, NamedTuple, Self, TypedDict from uuid import UUID, uuid4 -import streamlit as st from sqlalchemy import BigInteger, Column, Float, ForeignKey, Integer, String, Text, desc, func, select, text, update from sqlalchemy.dialects import postgresql from sqlalchemy.orm.attributes import flag_modified @@ -161,7 +160,6 @@ def get_job_execution_ids(cls, test_run_ids: list[UUID]) -> dict[UUID, UUID | No return {row.id: row.job_execution_id for row in rows} @classmethod - @st.cache_data(show_spinner=False) def get_minimal(cls, run_id: str | UUID) -> TestRunMinimal | None: if not is_uuid4(run_id): return None diff --git a/testgen/common/models/test_suite.py b/testgen/common/models/test_suite.py index bd396eb1..ac29ddce 100644 --- a/testgen/common/models/test_suite.py +++ b/testgen/common/models/test_suite.py @@ -4,14 +4,13 @@ from datetime import datetime from uuid import UUID, uuid4 -import streamlit as st from sqlalchemy import BigInteger, Boolean, Column, Enum, ForeignKey, Integer, String, asc, func, select, text from sqlalchemy.dialects import postgresql from sqlalchemy.orm import InstrumentedAttribute from testgen.common.models import get_current_session from testgen.common.models.custom_types import NullIfEmptyString, YNString -from testgen.common.models.entity import ENTITY_HASH_FUNCS, Entity, EntityMinimal +from testgen.common.models.entity import Entity, EntityMinimal from testgen.utils import is_uuid4 @@ -94,13 +93,11 @@ def get_regular(cls, identifier: str | UUID) -> "TestSuite | None": @classmethod - @st.cache_data(show_spinner=False) def get_minimal(cls, identifier: int) -> TestSuiteMinimal | None: result = cls._get_columns(identifier, cls._minimal_columns) return TestSuiteMinimal(**result) if result else None @classmethod - @st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) def select_minimal_where( cls, *clauses, order_by: tuple[str | InstrumentedAttribute] = _default_order_by ) -> Iterable[TestSuiteMinimal]: diff --git a/testgen/common/models/user.py b/testgen/common/models/user.py index b4e1d575..c96f271f 100644 --- a/testgen/common/models/user.py +++ b/testgen/common/models/user.py @@ -2,7 +2,6 @@ from typing import Self from uuid import UUID, uuid4 -import streamlit as st from sqlalchemy import Boolean, Column, String, asc, func, select, update from sqlalchemy.dialects import postgresql @@ -42,7 +41,6 @@ def save(self, update_latest_login: bool = False) -> None: super().save() @classmethod - @st.cache_data(show_spinner=False) def get(cls, identifier: str) -> Self | None: query = select(cls).where(func.lower(User.username) == func.lower(identifier)) return get_current_session().scalars(query).first() diff --git a/testgen/ui/app.py b/testgen/ui/app.py index be61a443..b67cb389 100644 --- a/testgen/ui/app.py +++ b/testgen/ui/app.py @@ -8,12 +8,12 @@ from testgen.common import version_service from testgen.common.docker_service import check_basic_configuration from testgen.common.models import get_current_session, with_database_session -from testgen.common.models.project import Project from testgen.common.standalone_postgres import STANDALONE_URI_ENV_VAR, ensure_standalone_setup, is_standalone_mode from testgen.ui import bootstrap from testgen.ui.assets import get_asset_path from testgen.ui.components import widgets as testgen from testgen.ui.services import javascript_service +from testgen.ui.services.query_cache import select_projects_where from testgen.ui.session import session if is_standalone_mode() and (standalone_uri := os.environ.get(STANDALONE_URI_ENV_VAR)): @@ -72,7 +72,7 @@ def render(log_level: int = logging.INFO): with st.sidebar: testgen.sidebar( projects=[] if is_global_context else [ - p for p in Project.select_where() if session.auth.user_has_project_access(p.project_code) + p for p in select_projects_where() if session.auth.user_has_project_access(p.project_code) ], current_project=None if is_global_context else session.sidebar_project, menu=application.menu, diff --git a/testgen/ui/auth.py b/testgen/ui/auth.py index 2abe32e8..8bb5b788 100644 --- a/testgen/ui/auth.py +++ b/testgen/ui/auth.py @@ -9,7 +9,11 @@ from testgen.common.models.project_membership import RoleType from testgen.common.models.user import User from testgen.ui.services.javascript_service import execute_javascript -from testgen.ui.services.query_cache import get_membership_by_user_and_project +from testgen.ui.services.query_cache import ( + get_membership_by_user_and_project, + get_user, + select_users_where, +) from testgen.ui.session import session LOG = logging.getLogger("testgen") @@ -62,7 +66,7 @@ def get_jwt_hashing_key(self) -> bytes: st.stop() def get_credentials(self): - users = User.select_where() + users = select_users_where() usernames = {} for item in users: usernames[item.username.lower()] = { @@ -72,7 +76,7 @@ def get_credentials(self): return {"usernames": usernames} def login_user(self, username: str) -> None: - self.user = User.get(username) + self.user = get_user(username) self.user.save(update_latest_login=True) self.load_user_role() MixpanelService().send_event("login", include_usage=True, role=self.role) @@ -83,7 +87,7 @@ def load_user_session(self) -> None: if token is not None: try: payload = decode_jwt_token(token) - self.user = User.get(payload["username"]) + self.user = get_user(payload["username"]) self.load_user_role() except Exception: LOG.debug("Invalid auth token found on cookies", exc_info=True, stack_info=True) diff --git a/testgen/ui/queries/table_group_queries.py b/testgen/ui/queries/table_group_queries.py index f52fdf0e..0db27e12 100644 --- a/testgen/ui/queries/table_group_queries.py +++ b/testgen/ui/queries/table_group_queries.py @@ -12,6 +12,7 @@ from testgen.common.models.connection import Connection from testgen.common.models.table_group import TableGroup from testgen.ui.services.database_service import fetch_from_target_db +from testgen.ui.services.query_cache import get_connection class StatsPreview(TypedDict): @@ -56,7 +57,7 @@ def get_table_group_preview( if connection or table_group.connection_id: try: - connection = connection or Connection.get(table_group.connection_id) + connection = connection or get_connection(table_group.connection_id) table_group_preview, data_chars, sql_generator = _get_preview(table_group, connection) def save_data_chars(table_group_id: UUID) -> None: diff --git a/testgen/ui/services/query_cache.py b/testgen/ui/services/query_cache.py index c90afe0b..55be1e67 100644 --- a/testgen/ui/services/query_cache.py +++ b/testgen/ui/services/query_cache.py @@ -12,15 +12,28 @@ import streamlit as st -from testgen.common.models.connection import Connection -from testgen.common.models.profiling_run import ProfilingRun, ProfilingRunSummary +from testgen.common.models.connection import Connection, ConnectionMinimal +from testgen.common.models.entity import ENTITY_HASH_FUNCS +from testgen.common.models.profiling_run import ProfilingRun, ProfilingRunMinimal, ProfilingRunSummary from testgen.common.models.project import Project, ProjectSummary from testgen.common.models.project_membership import ProjectMembership from testgen.common.models.scheduler import RUN_MONITORS_JOB_KEY, JobSchedule -from testgen.common.models.table_group import TableGroup, TableGroupStats, TableGroupSummary -from testgen.common.models.test_definition import TestType, TestTypeSummary -from testgen.common.models.test_run import TestRun, TestRunSummary -from testgen.common.models.test_suite import TestSuite, TestSuiteSummary +from testgen.common.models.table_group import ( + TableGroup, + TableGroupMinimal, + TableGroupStats, + TableGroupSummary, +) +from testgen.common.models.test_definition import ( + TestDefinition, + TestDefinitionMinimal, + TestDefinitionSummary, + TestType, + TestTypeSummary, +) +from testgen.common.models.test_run import TestRun, TestRunMinimal, TestRunSummary +from testgen.common.models.test_suite import TestSuite, TestSuiteMinimal, TestSuiteSummary +from testgen.common.models.user import User # -- Project ------------------------------------------------------------------ @@ -137,3 +150,163 @@ def get_monitor_schedule(monitor_suite_id: str | UUID) -> JobSchedule | None: JobSchedule.key == RUN_MONITORS_JOB_KEY, JobSchedule.kwargs["test_suite_id"].astext == str(monitor_suite_id), ) + + +# -- Connection --------------------------------------------------------------- + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def get_connection(identifier: str | int | UUID, *clauses) -> Connection | None: + return Connection.get(identifier, *clauses) + + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def select_connections_where(*clauses, order_by=None) -> list[Connection]: + return list(Connection.select_where(*clauses, order_by=order_by)) + + +@st.cache_data(show_spinner=False) +def get_connection_minimal(identifier: int) -> ConnectionMinimal | None: + return Connection.get_minimal(identifier) + + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def select_connections_minimal_where(*clauses, order_by=None) -> list[ConnectionMinimal]: + if order_by is None: + return list(Connection.select_minimal_where(*clauses)) + return list(Connection.select_minimal_where(*clauses, order_by=order_by)) + + +# -- User --------------------------------------------------------------------- + +@st.cache_data(show_spinner=False) +def get_user(identifier: str) -> User | None: + return User.get(identifier) + + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def select_users_where(*clauses, order_by=None) -> list[User]: + return list(User.select_where(*clauses, order_by=order_by)) + + +# -- TableGroup --------------------------------------------------------------- + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def get_table_group(identifier: str | UUID, *clauses) -> TableGroup | None: + return TableGroup.get(identifier, *clauses) + + +@st.cache_data(show_spinner=False) +def get_table_group_minimal(identifier: str | UUID) -> TableGroupMinimal | None: + return TableGroup.get_minimal(identifier) + + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def select_table_groups_minimal_where(*clauses, order_by=None) -> list[TableGroupMinimal]: + if order_by is None: + return list(TableGroup.select_minimal_where(*clauses)) + return list(TableGroup.select_minimal_where(*clauses, order_by=order_by)) + + +# -- TestSuite ---------------------------------------------------------------- + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def get_test_suite(identifier: str | UUID, *clauses) -> TestSuite | None: + return TestSuite.get(identifier, *clauses) + + +@st.cache_data(show_spinner=False) +def get_test_suite_minimal(identifier: int) -> TestSuiteMinimal | None: + return TestSuite.get_minimal(identifier) + + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def select_test_suites_minimal_where(*clauses, order_by=None) -> list[TestSuiteMinimal]: + if order_by is None: + return list(TestSuite.select_minimal_where(*clauses)) + return list(TestSuite.select_minimal_where(*clauses, order_by=order_by)) + + +# -- TestRun ------------------------------------------------------------------ + +@st.cache_data(show_spinner=False) +def get_test_run_minimal(run_id: str | UUID) -> TestRunMinimal | None: + return TestRun.get_minimal(run_id) + + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def select_test_runs_where(*clauses, order_by=None) -> list[TestRun]: + return list(TestRun.select_where(*clauses, order_by=order_by)) + + +# -- ProfilingRun ------------------------------------------------------------- + +@st.cache_data(show_spinner=False) +def get_profiling_run_minimal(run_id: str | UUID) -> ProfilingRunMinimal | None: + return ProfilingRun.get_minimal(run_id) + + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def select_profiling_runs_where(*clauses, order_by=None) -> list[ProfilingRun]: + return list(ProfilingRun.select_where(*clauses, order_by=order_by)) + + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def select_profiling_runs_minimal_where(*clauses, order_by=None) -> list[ProfilingRunMinimal]: + if order_by is None: + return list(ProfilingRun.select_minimal_where(*clauses)) + return list(ProfilingRun.select_minimal_where(*clauses, order_by=order_by)) + + +# -- TestDefinition ----------------------------------------------------------- + +@st.cache_data(show_spinner=False) +def get_test_definition(identifier: str | UUID) -> TestDefinitionSummary | None: + return TestDefinition.get(identifier) + + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def select_test_definitions_where(*clauses, order_by=None) -> list[TestDefinitionSummary]: + if order_by is None: + return list(TestDefinition.select_where(*clauses)) + return list(TestDefinition.select_where(*clauses, order_by=order_by)) + + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def select_test_definitions_minimal_where(*clauses, order_by=None) -> list[TestDefinitionMinimal]: + if order_by is None: + return list(TestDefinition.select_minimal_where(*clauses)) + return list(TestDefinition.select_minimal_where(*clauses, order_by=order_by)) + + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def select_test_definitions_page( + *clauses, + order_by=None, + page: int = 1, + limit: int = 500, +) -> tuple[list[TestDefinitionSummary], int]: + return TestDefinition.select_page(*clauses, order_by=order_by, page=page, limit=limit) + + +# -- Project ------------------------------------------------------------------ + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def get_project(identifier: str, *clauses) -> Project | None: + return Project.get(identifier, *clauses) + + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def select_projects_where(*clauses, order_by=None) -> list[Project]: + return list(Project.select_where(*clauses, order_by=order_by)) + + +# -- ProjectMembership -------------------------------------------------------- + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def get_project_membership(identifier: str | UUID, *clauses) -> ProjectMembership | None: + return ProjectMembership.get(identifier, *clauses) + + +@st.cache_data(show_spinner=False, hash_funcs=ENTITY_HASH_FUNCS) +def select_project_memberships_where(*clauses, order_by=None) -> list[ProjectMembership]: + return list(ProjectMembership.select_where(*clauses, order_by=order_by)) diff --git a/testgen/ui/views/connections.py b/testgen/ui/views/connections.py index 5faaf0a4..619e10ed 100644 --- a/testgen/ui/views/connections.py +++ b/testgen/ui/views/connections.py @@ -30,6 +30,11 @@ from testgen.ui.components import widgets as testgen from testgen.ui.navigation.menu import MenuItem from testgen.ui.navigation.page import Page +from testgen.ui.services.query_cache import ( + get_connection, + select_connections_where, + select_table_groups_minimal_where, +) from testgen.ui.session import session, temp_value from testgen.ui.utils import get_cron_sample_handler @@ -71,14 +76,14 @@ def render(self, project_code: str, **_kwargs) -> None: "connect-your-database/manage-connections/", ) - connections = Connection.select_where(Connection.project_code == project_code) + connections = select_connections_where(Connection.project_code == project_code) connection: Connection = connections[0] if len(connections) > 0 else Connection( sql_flavor="postgresql", sql_flavor_code="postgresql", project_code=project_code, ) has_table_groups = ( - connection.id and len(TableGroup.select_minimal_where(TableGroup.connection_id == connection.connection_id) or []) > 0 + connection.id and len(select_table_groups_minimal_where(TableGroup.connection_id == connection.connection_id) or []) > 0 ) user_is_admin = session.auth.user_has_permission("administer") @@ -186,8 +191,8 @@ def on_setup_table_group_clicked(*_args) -> None: success = True try: connection.save() - Connection.select_where.clear() - Connection.get.clear() + select_connections_where.clear() + get_connection.clear() message = "Changes have been saved successfully." except Exception as error: message = "Something went wrong while creating the connection." diff --git a/testgen/ui/views/data_catalog.py b/testgen/ui/views/data_catalog.py index 10a6128d..d316f237 100644 --- a/testgen/ui/views/data_catalog.py +++ b/testgen/ui/views/data_catalog.py @@ -47,7 +47,14 @@ get_tables_by_table_group, ) from testgen.ui.services.database_service import execute_db_query, fetch_all_from_db, fetch_from_target_db -from testgen.ui.services.query_cache import get_profiling_run_summaries, get_project_summary, get_table_group_stats +from testgen.ui.services.query_cache import ( + get_profiling_run_summaries, + get_project_summary, + get_table_group, + get_table_group_stats, + select_profiling_runs_minimal_where, + select_table_groups_minimal_where, +) from testgen.ui.session import session from testgen.ui.views.dialogs.import_metadata_dialog import ( apply_metadata_import, @@ -100,7 +107,7 @@ def render(self, project_code: str, table_group_id: str | None = None, selected: project_summary = get_project_summary(project_code) user_can_navigate = session.auth.user_has_permission("view") - table_groups = TableGroup.select_minimal_where(TableGroup.project_code == project_code) + table_groups = select_table_groups_minimal_where(TableGroup.project_code == project_code) if not table_group_id or table_group_id not in [ str(item.id) for item in table_groups ]: table_group_id = str(table_groups[0].id) if table_groups else None @@ -203,7 +210,7 @@ def on_import_confirmed(_) -> None: try: apply_metadata_import(preview, tg_id) from testgen.ui.queries.profiling_queries import get_column_by_id, get_table_by_id - for func in [get_table_group_columns, get_table_by_id, get_column_by_id, get_tag_values, TableGroup.select_minimal_where]: + for func in [get_table_group_columns, get_table_by_id, get_column_by_id, get_tag_values, select_table_groups_minimal_where]: func.clear() st.session_state["data_catalog:last_saved_timestamp"] = datetime.now().timestamp() parts = [] @@ -664,7 +671,7 @@ def on_tags_changed(spinner_container: DeltaGenerator, payload: dict) -> FILE_DA if disable_flags: table_group_id = st.query_params.get("table_group_id") if table_group_id: - table_group = TableGroup.get(table_group_id) + table_group = get_table_group(table_group_id) changed = False if "profile_flag_cdes" in disable_flags and table_group.profile_flag_cdes: table_group.profile_flag_cdes = False @@ -675,7 +682,7 @@ def on_tags_changed(spinner_container: DeltaGenerator, payload: dict) -> FILE_DA if changed: table_group.save() - for func in [ get_table_group_columns, get_table_by_id, get_column_by_id, get_tag_values, TableGroup.select_minimal_where ]: + for func in [ get_table_group_columns, get_table_by_id, get_column_by_id, get_tag_values, select_table_groups_minimal_where ]: func.clear() st.session_state["data_catalog:last_saved_timestamp"] = datetime.now().timestamp() @@ -832,7 +839,7 @@ def _build_history_dialog_data( column_name: str, add_date: int, ) -> dict | None: - profiling_runs = ProfilingRun.select_minimal_where( + profiling_runs = select_profiling_runs_minimal_where( ProfilingRun.table_groups_id == table_group_id, ProfilingRun.profiling_starttime >= sa_func.to_timestamp(add_date), ) diff --git a/testgen/ui/views/dialogs/import_metadata_dialog.py b/testgen/ui/views/dialogs/import_metadata_dialog.py index 524750ea..b6b1f4c8 100644 --- a/testgen/ui/views/dialogs/import_metadata_dialog.py +++ b/testgen/ui/views/dialogs/import_metadata_dialog.py @@ -4,9 +4,9 @@ import pandas as pd -from testgen.common.models.table_group import TableGroup from testgen.ui.queries.profiling_queries import TAG_FIELDS from testgen.ui.services.database_service import execute_db_query, fetch_all_from_db +from testgen.ui.services.query_cache import get_table_group from testgen.ui.session import session LOG = logging.getLogger("testgen") @@ -192,7 +192,7 @@ def _match_and_validate( matched_columns = sum(1 for r in preview_rows if r.get("column_name") and r.get("_status") in _importable) skipped = sum(1 for r in preview_rows if r.get("_status") not in _importable) - table_group = TableGroup.get(table_group_id) + table_group = get_table_group(table_group_id) return { "table_rows": table_rows, @@ -328,7 +328,7 @@ def apply_metadata_import(preview: dict, table_group_id: str | None = None) -> d def _disable_autoflags(table_group_id: str, metadata_columns: list[str]) -> None: - table_group = TableGroup.get(table_group_id) + table_group = get_table_group(table_group_id) changed = False if "critical_data_element" in metadata_columns and table_group.profile_flag_cdes: table_group.profile_flag_cdes = False diff --git a/testgen/ui/views/dialogs/run_tests_dialog.py b/testgen/ui/views/dialogs/run_tests_dialog.py index 33797485..0c4a9d72 100644 --- a/testgen/ui/views/dialogs/run_tests_dialog.py +++ b/testgen/ui/views/dialogs/run_tests_dialog.py @@ -6,7 +6,7 @@ from testgen.common.models.test_suite import TestSuite from testgen.ui.components import widgets as testgen from testgen.ui.navigation.router import Router -from testgen.ui.services.query_cache import get_test_run_summaries +from testgen.ui.services.query_cache import get_test_run_summaries, select_test_suites_minimal_where from testgen.ui.session import session LINK_HREF = "test-runs" @@ -19,7 +19,7 @@ def run_tests_dialog_widget( on_close: callable, test_suite_id: str | None = None, ) -> None: - test_suites = TestSuite.select_minimal_where( + test_suites = select_test_suites_minimal_where( TestSuite.project_code == project_code, TestSuite.is_monitor.isnot(True), ) diff --git a/testgen/ui/views/hygiene_issues.py b/testgen/ui/views/hygiene_issues.py index 50dbcf50..74bd024f 100644 --- a/testgen/ui/views/hygiene_issues.py +++ b/testgen/ui/views/hygiene_issues.py @@ -10,7 +10,6 @@ from testgen.common.mixpanel_service import MixpanelService from testgen.common.models import with_database_session from testgen.common.models.hygiene_issue import HygieneIssue -from testgen.common.models.profiling_run import ProfilingRun from testgen.common.pii_masking import get_pii_columns, mask_hygiene_detail, mask_profiling_pii from testgen.ui.components import widgets as testgen from testgen.ui.components.widgets.download_dialog import ( @@ -26,6 +25,7 @@ from testgen.ui.queries.profiling_queries import get_profiling_anomalies from testgen.ui.queries.source_data_queries import get_hygiene_issue_source_data, get_hygiene_issue_source_query from testgen.ui.services.database_service import execute_db_query +from testgen.ui.services.query_cache import get_profiling_run_minimal from testgen.ui.session import session from testgen.utils import friendly_score, make_json_safe @@ -92,7 +92,7 @@ def render( sort: str | None = None, **_kwargs, ) -> None: - run = ProfilingRun.get_minimal(run_id) + run = get_profiling_run_minimal(run_id) if not run: self.router.navigate_with_warning( f"Profiling run with ID '{run_id}' does not exist. Redirecting to list of Profiling Runs ...", diff --git a/testgen/ui/views/monitors_dashboard.py b/testgen/ui/views/monitors_dashboard.py index 1b6becb9..e4fd3719 100644 --- a/testgen/ui/views/monitors_dashboard.py +++ b/testgen/ui/views/monitors_dashboard.py @@ -25,7 +25,16 @@ from testgen.ui.navigation.router import Router from testgen.ui.queries.profiling_queries import get_tables_by_table_group from testgen.ui.services.database_service import execute_db_query, fetch_all_from_db, fetch_one_from_db -from testgen.ui.services.query_cache import get_monitor_schedule, get_project_summary, get_test_type_summaries +from testgen.ui.services.query_cache import ( + get_monitor_schedule, + get_project_summary, + get_table_group, + get_test_definition, + get_test_suite, + get_test_type_summaries, + select_table_groups_minimal_where, + select_test_definitions_where, +) from testgen.ui.services.rerun_service import safe_rerun from testgen.ui.session import session, temp_value from testgen.ui.utils import dict_from_kv, get_cron_sample_handler @@ -86,7 +95,7 @@ def render( ) project_summary = get_project_summary(project_code) - table_groups = TableGroup.select_minimal_where(TableGroup.project_code == project_code) + table_groups = select_table_groups_minimal_where(TableGroup.project_code == project_code) if not table_group_id or table_group_id not in [ str(item.id) for item in table_groups ]: table_group_id = str(table_groups[0].id) if table_groups else None @@ -595,7 +604,7 @@ def build_edit_monitor_settings_data( monitor_suite_id = table_group.monitor_test_suite_id if monitor_suite_id: - monitor_suite = TestSuite.get(monitor_suite_id) + monitor_suite = get_test_suite(monitor_suite_id) else: monitor_suite = TestSuite( project_code=table_group.project_code, @@ -645,7 +654,7 @@ def on_save_settings_clicked(payload: dict) -> None: JobSchedule.update_active(schedule.id, new_schedule_config["active"]) if is_new: - updated_table_group = TableGroup.get(table_group.id) + updated_table_group = get_table_group(table_group.id) updated_table_group.monitor_test_suite_id = monitor_suite.id updated_table_group.save() # Commit needed to make test suite visible to run_monitor_generation's separate DB connection @@ -677,7 +686,7 @@ def on_save_settings_clicked(payload: dict) -> None: @with_database_session def delete_monitor_suite(table_group: TableGroupMinimal) -> None: try: - monitor_suite = TestSuite.get(table_group.monitor_test_suite_id) + monitor_suite = get_test_suite(table_group.monitor_test_suite_id) TestSuite.cascade_delete([monitor_suite.id]) st.cache_data.clear() except Exception: @@ -746,7 +755,7 @@ def on_close_trends(_payload=None): lookback_multiplier = 3 if extended_history else 1 events = get_monitor_events_for_table(table_group.monitor_test_suite_id, table_name, lookback_multiplier) - definitions = TestDefinition.select_where( + definitions = select_test_definitions_where( TestDefinition.test_suite_id == table_group.monitor_test_suite_id, TestDefinition.table_name == table_name, TestDefinition.test_type.in_(["Freshness_Trend", "Volume_Trend", "Metric_Trend"]), @@ -754,7 +763,7 @@ def on_close_trends(_payload=None): predictions = {} if len(definitions) > 0: - test_suite = TestSuite.get(table_group.monitor_test_suite_id) + test_suite = get_test_suite(table_group.monitor_test_suite_id) monitor_schedule = get_monitor_schedule(table_group.monitor_test_suite_id) monitor_lookback = test_suite.monitor_lookback predict_sensitivity = test_suite.predict_sensitivity or PredictSensitivity.medium @@ -1016,7 +1025,7 @@ def build_edit_table_monitors_data( table_group: TableGroupMinimal, payload: dict, dialog: dict | None = None, ) -> tuple[dict, dict]: table_name = payload.get("table_name") - definitions = TestDefinition.select_where( + definitions = select_test_definitions_where( TestDefinition.test_suite_id == table_group.monitor_test_suite_id, TestDefinition.table_name == table_name, TestDefinition.test_type.in_(["Freshness_Trend", "Volume_Trend", "Metric_Trend"]), @@ -1040,7 +1049,7 @@ def on_save_test_definition(payload: dict) -> None: valid_columns = {col.name for col in TestDefinition.__table__.columns} for updated_def in get_updated_definitions(): - current_def: TestDefinitionSummary = TestDefinition.get(updated_def.get("id")) + current_def: TestDefinitionSummary = get_test_definition(updated_def.get("id")) if current_def: merged = {key: getattr(current_def, key, None) for key in valid_columns} merged.update({key: value for key, value in updated_def.items() if key in valid_columns}) diff --git a/testgen/ui/views/profiling_results.py b/testgen/ui/views/profiling_results.py index d777d297..1c5cdd7c 100644 --- a/testgen/ui/views/profiling_results.py +++ b/testgen/ui/views/profiling_results.py @@ -8,7 +8,6 @@ from testgen.common import date_service from testgen.common.date_service import parse_fuzzy_date from testgen.common.models import with_database_session -from testgen.common.models.profiling_run import ProfilingRun from testgen.common.pii_masking import ( PII_REDACTED, get_pii_columns, @@ -26,6 +25,7 @@ ) from testgen.ui.navigation.page import Page from testgen.ui.navigation.router import Router +from testgen.ui.services.query_cache import get_profiling_run_minimal from testgen.ui.session import session from testgen.ui.views.data_catalog import get_preview_data from testgen.utils import make_json_safe @@ -87,7 +87,7 @@ def render( sort: str | None = None, **_kwargs, ) -> None: - run = ProfilingRun.get_minimal(run_id) + run = get_profiling_run_minimal(run_id) if not run: self.router.navigate_with_warning( f"Profiling run with ID '{run_id}' does not exist. Redirecting to list of Profiling Runs ...", diff --git a/testgen/ui/views/profiling_runs.py b/testgen/ui/views/profiling_runs.py index af705c6f..5707838a 100644 --- a/testgen/ui/views/profiling_runs.py +++ b/testgen/ui/views/profiling_runs.py @@ -27,7 +27,13 @@ from testgen.ui.navigation.menu import MenuItem from testgen.ui.navigation.page import Page from testgen.ui.navigation.router import Router -from testgen.ui.services.query_cache import get_profiling_run_summaries, get_project_summary, get_table_group_stats +from testgen.ui.services.query_cache import ( + get_profiling_run_summaries, + get_project_summary, + get_table_group_stats, + select_profiling_runs_where, + select_table_groups_minimal_where, +) from testgen.ui.session import session from testgen.ui.views.dialogs.manage_notifications import NotificationSettingsDialogBase from testgen.ui.views.dialogs.manage_schedules import ScheduleDialog @@ -62,7 +68,7 @@ def render(self, project_code: str, table_group_id: str | None = None, **_kwargs with st.spinner("Loading data ..."): project_summary = get_project_summary(project_code) profiling_runs, total_count = get_profiling_run_summaries(project_code, table_group_id, page=page) - table_groups = TableGroup.select_minimal_where(TableGroup.project_code == project_code) + table_groups = select_table_groups_minimal_where(TableGroup.project_code == project_code) schedule_obj = ProfilingScheduleDialog(project_code) ns_obj = ProfilingRunNotificationSettingsDialog( @@ -225,7 +231,7 @@ class ProfilingScheduleDialog(ScheduleDialog): table_groups: Iterable[TableGroupMinimal] | None = None def init(self) -> None: - self.table_groups = TableGroup.select_minimal_where(TableGroup.project_code == self.project_code) + self.table_groups = select_table_groups_minimal_where(TableGroup.project_code == self.project_code) def get_arg_value(self, job): return next(item.table_groups_name for item in self.table_groups if str(item.id) == job.kwargs["table_group_id"]) @@ -259,7 +265,7 @@ def _model_to_item_attrs(self, model: ProfilingRunNotificationSettings) -> dict[ def _get_component_props(self) -> dict[str, typing.Any]: table_group_options = [ (str(tg.id), tg.table_groups_name) - for tg in TableGroup.select_minimal_where(TableGroup.project_code == self.ns_attrs["project_code"]) + for tg in select_table_groups_minimal_where(TableGroup.project_code == self.ns_attrs["project_code"]) ] table_group_options.insert(0, (None, "All Table Groups")) trigger_labels = { @@ -301,7 +307,7 @@ def on_delete_runs(job_execution_ids: list[str]) -> None: continue if job_exec.status in (JobStatus.PENDING, JobStatus.CLAIMED, JobStatus.RUNNING, JobStatus.CANCEL_REQUESTED): job_exec.request_cancel() - profiling_run = next(iter(ProfilingRun.select_where(ProfilingRun.job_execution_id == je_id)), None) + profiling_run = next(iter(select_profiling_runs_where(ProfilingRun.job_execution_id == je_id)), None) if profiling_run: ProfilingRun.cascade_delete([str(profiling_run.id)]) get_current_session().delete(job_exec) diff --git a/testgen/ui/views/project_settings.py b/testgen/ui/views/project_settings.py index db0b5011..7cc37900 100644 --- a/testgen/ui/views/project_settings.py +++ b/testgen/ui/views/project_settings.py @@ -12,6 +12,7 @@ from testgen.ui.components import widgets as testgen from testgen.ui.navigation.menu import MenuItem from testgen.ui.navigation.page import Page +from testgen.ui.services.query_cache import get_project, select_projects_where from testgen.ui.session import session, temp_value PAGE_TITLE = "Project Settings" @@ -35,7 +36,7 @@ class ProjectSettingsPage(Page): existing_names: list[str] | None = None def render(self, project_code: str | None = None, **_kwargs) -> None: - self.project = Project.get(project_code) + self.project = get_project(project_code) testgen.page_header( PAGE_TITLE, @@ -64,7 +65,7 @@ def on_observability_connection_test(payload: dict) -> None: @with_database_session def update_project(self, project_code: str, edited_project: dict) -> None: existing_names = [ - p.project_name.lower() for p in Project.select_where(Project.project_code != project_code) + p.project_name.lower() for p in select_projects_where(Project.project_code != project_code) ] new_project_name = edited_project["name"] if new_project_name.lower() in existing_names: diff --git a/testgen/ui/views/table_groups.py b/testgen/ui/views/table_groups.py index 2b909117..bb9e186c 100644 --- a/testgen/ui/views/table_groups.py +++ b/testgen/ui/views/table_groups.py @@ -22,7 +22,16 @@ from testgen.ui.navigation.page import Page from testgen.ui.navigation.router import Router from testgen.ui.queries import table_group_queries -from testgen.ui.services.query_cache import get_profiling_run_summaries, get_project_summary, get_table_group_stats +from testgen.ui.services.query_cache import ( + get_connection_minimal, + get_profiling_run_summaries, + get_project_summary, + get_table_group, + get_table_group_minimal, + get_table_group_stats, + select_connections_minimal_where, + select_table_groups_minimal_where, +) from testgen.ui.services.rerun_service import safe_rerun from testgen.ui.session import session, temp_value from testgen.ui.utils import get_cron_sample_handler @@ -74,7 +83,7 @@ def render( if table_group_name: table_group_filters.append(TableGroup.table_groups_name.ilike(f"%{table_group_name}%")) - table_groups = TableGroup.select_minimal_where(*table_group_filters) + table_groups = select_table_groups_minimal_where(*table_group_filters) connections = self._get_connections(project_code) wizard_mode = st.session_state.get("tg_wizard_mode") @@ -282,7 +291,7 @@ def on_save_table_group_clicked(payload: dict): set_run_profiling(run_profiling) def on_close_clicked(_params: dict) -> None: - TableGroup.select_minimal_where.clear() + select_table_groups_minimal_where.clear() for key in ["tg_wizard_mode", "tg_wizard_connection_id", "tg_wizard_table_group_id"]: st.session_state.pop(key, None) @@ -328,7 +337,7 @@ def on_close_clicked(_params: dict) -> None: table_group = TableGroup(project_code=project_code) original_table_group_schema = None if table_group_id: - table_group = TableGroup.get(table_group_id) + table_group = get_table_group(table_group_id) original_table_group_schema = table_group.table_group_schema is_table_group_used = TableGroup.is_in_use([table_group_id]) @@ -525,7 +534,7 @@ def on_close_edit(_params: dict) -> None: for key in ["tg_wizard_mode", "tg_wizard_table_group_id"]: st.session_state.pop(key, None) - table_group = TableGroup.get(table_group_id) + table_group = get_table_group(table_group_id) original_schema = table_group.table_group_schema is_in_use = TableGroup.is_in_use([table_group_id]) @@ -592,9 +601,9 @@ def on_close_edit(_params: dict) -> None: def _get_connections(self, project_code: str, connection_id: str | None = None) -> list[dict]: if connection_id: - connections = [Connection.get_minimal(connection_id)] + connections = [get_connection_minimal(connection_id)] else: - connections = Connection.select_minimal_where(Connection.project_code == project_code) + connections = select_connections_minimal_where(Connection.project_code == project_code) return [ format_connection(connection) for connection in connections ] def _format_table_group_list( @@ -622,7 +631,7 @@ def _format_table_group_list( @with_database_session def _prepare_delete_dialog(self, table_group_id: str) -> None: - table_group = TableGroup.get_minimal(table_group_id) + table_group = get_table_group_minimal(table_group_id) can_be_deleted = not TableGroup.is_in_use([table_group_id]) st.session_state["tg_delete_dialog"] = { "open": True, @@ -635,7 +644,7 @@ def _execute_delete(self, table_group_id: str) -> None: table_group_name = st.session_state.get("tg_delete_dialog", {}).get("table_group", {}).get("table_groups_name", "") if not (ProfilingRun.has_active_job_for(TableGroup, table_group_id) or TestRun.has_active_job_for(TableGroup, table_group_id)): TableGroup.cascade_delete([table_group_id]) - TableGroup.select_minimal_where.clear() + select_table_groups_minimal_where.clear() st.toast(f"Table Group {table_group_name} has been deleted.", icon=":material/check:") else: st.toast("This Table Group is in use by a running process and cannot be deleted.", icon=":material/error:") diff --git a/testgen/ui/views/test_definitions.py b/testgen/ui/views/test_definitions.py index 05f3390a..4ba66240 100644 --- a/testgen/ui/views/test_definitions.py +++ b/testgen/ui/views/test_definitions.py @@ -12,7 +12,6 @@ from testgen.common.database.database_service import get_flavor_service from testgen.common.enums import JobSource from testgen.common.models import with_database_session -from testgen.common.models.connection import Connection from testgen.common.models.job_execution import JobExecution from testgen.common.models.table_group import TableGroup, TableGroupMinimal from testgen.common.models.test_definition import ( @@ -35,6 +34,16 @@ from testgen.ui.navigation.router import Router from testgen.ui.queries import profiling_queries from testgen.ui.services.database_service import fetch_all_from_db, fetch_df_from_db, fetch_from_target_db +from testgen.ui.services.query_cache import ( + get_connection, + get_table_group_minimal, + get_test_suite, + select_table_groups_minimal_where, + select_test_definitions_minimal_where, + select_test_definitions_page, + select_test_definitions_where, + select_test_suites_minimal_where, +) from testgen.ui.session import session from testgen.utils import make_json_safe, to_dataframe @@ -107,7 +116,7 @@ def render( sort: str | None = None, **_kwargs, ) -> None: - test_suite = TestSuite.get(test_suite_id) + test_suite = get_test_suite(test_suite_id) if not test_suite: self.router.navigate_with_warning( f"Test suite with ID '{test_suite_id}' does not exist. Redirecting to list of Test Suites ...", @@ -115,7 +124,7 @@ def render( ) return - table_group = TableGroup.get_minimal(test_suite.table_groups_id) + table_group = get_table_group_minimal(test_suite.table_groups_id) project_code = table_group.project_code if not session.auth.user_has_project_access(project_code): @@ -149,8 +158,8 @@ def render( test_types = run_test_type_lookup_query().to_dict("records") table_columns = get_columns(str(table_group.id)) filter_columns_df = get_test_suite_columns(test_suite_id) - table_groups = TableGroup.select_minimal_where(TableGroup.project_code == project_code) - all_test_suites = TestSuite.select_minimal_where( + table_groups = select_table_groups_minimal_where(TableGroup.project_code == project_code) + all_test_suites = select_test_suites_minimal_where( TestSuite.table_groups_id.in_([str(tg.id) for tg in table_groups]), TestSuite.is_monitor.isnot(True), ) @@ -184,7 +193,7 @@ def render( qualifies_table_refs_with_schema = True if st.session_state.get(TD_ADD_DIALOG_KEY) or st.session_state.get(TD_EDIT_DIALOG_KEY): - connection = Connection.get(table_group.connection_id) + connection = get_connection(table_group.connection_id) if connection: qualifies_table_refs_with_schema = get_flavor_service(connection.sql_flavor).qualifies_table_refs_with_schema @@ -296,7 +305,7 @@ def on_unlock_all_opened(*_) -> None: def on_copy_move_dialog_opened(selected) -> None: if selected == "all": all_ids = get_test_definition_ids(test_suite, table_name, column_name, test_type, flagged_filter=flagged) - results = TestDefinition.select_where(TestDefinition.id.in_(all_ids)) + results = select_test_definitions_where(TestDefinition.id.in_(all_ids)) selected = [ {"id": str(r.id), "table_name": r.table_name, "column_name": r.column_name, "test_type": r.test_type, "lock_refresh": r.lock_refresh} @@ -799,7 +808,7 @@ def run_test_type_lookup_query(test_type: str | None = None) -> pd.DataFrame: @st.cache_data(show_spinner=False) def get_test_suite_columns(test_suite_id: str) -> pd.DataFrame: - results = TestDefinition.select_minimal_where( + results = select_test_definitions_minimal_where( TestDefinition.test_suite_id == test_suite_id, order_by=(asc(func.lower(TestDefinition.table_name)), asc(func.lower(TestDefinition.column_name))), ) @@ -819,8 +828,8 @@ def get_test_definitions( """Return ``(df, total_count)`` for test definitions matching the given filters. When ``page_index`` is provided (0-based), fetches only that page from - the DB using ``TestDefinition.select_page()``; otherwise fetches all rows - via ``select_where()``. ``total_count`` is always the full matching count. + the DB using ``select_test_definitions_page()``; otherwise fetches all rows + via ``select_test_definitions_where()``. ``total_count`` is always the full matching count. """ clauses = [TestDefinition.test_suite_id == test_suite.id] if table_name: @@ -852,14 +861,14 @@ def get_test_definitions( order_by_tuple = tuple(order_by) if order_by else None if page_index is not None: - test_definitions, total_count = TestDefinition.select_page( + test_definitions, total_count = select_test_definitions_page( *clauses, order_by=order_by_tuple, page=page_index + 1, limit=page_size, ) else: - test_definitions = TestDefinition.select_where(*clauses, order_by=order_by_tuple) + test_definitions = select_test_definitions_where(*clauses, order_by=order_by_tuple) total_count = len(test_definitions) df = to_dataframe(test_definitions, TestDefinitionSummary.columns()) @@ -913,7 +922,7 @@ def get_test_definition_ids( clauses.append(TestDefinition.flagged == True) elif flagged_filter == "Not Flagged": clauses.append(TestDefinition.flagged == False) - results = TestDefinition.select_where(*clauses) + results = select_test_definitions_where(*clauses) return [str(r.id) for r in results] @@ -934,7 +943,7 @@ def get_test_definitions_collision( for item in test_definitions if item["column_name"] is not None ] - results = TestDefinition.select_minimal_where( + results = select_test_definitions_minimal_where( TestDefinition.table_groups_id == target_table_group_id, TestDefinition.test_suite_id == target_test_suite_id, TestDefinition.last_auto_gen_date.isnot(None), @@ -965,7 +974,7 @@ def get_columns(table_groups_id: str) -> list[dict]: def validate_test(test_definition: dict, table_group: TableGroupMinimal) -> None: schema = test_definition["schema_name"] table_name = test_definition["table_name"] - connection = Connection.get(table_group.connection_id) + connection = get_connection(table_group.connection_id) if test_definition["test_type"] == "Condition_Flag": condition = test_definition["custom_query"] diff --git a/testgen/ui/views/test_results.py b/testgen/ui/views/test_results.py index 014e4182..9c26ea21 100644 --- a/testgen/ui/views/test_results.py +++ b/testgen/ui/views/test_results.py @@ -11,10 +11,8 @@ from testgen.common import date_service from testgen.common.mixpanel_service import MixpanelService from testgen.common.models import with_database_session -from testgen.common.models.table_group import TableGroup from testgen.common.models.test_definition import TestDefinition, TestDefinitionNote, TestDefinitionSummary -from testgen.common.models.test_run import TestRun -from testgen.common.models.test_suite import TestSuite, TestSuiteMinimal +from testgen.common.models.test_suite import TestSuiteMinimal from testgen.common.pii_masking import get_pii_columns, mask_profiling_pii from testgen.ui.components import widgets as testgen from testgen.ui.components.widgets.download_dialog import ( @@ -35,6 +33,14 @@ get_test_issue_source_query_custom, ) from testgen.ui.services.database_service import execute_db_query, fetch_df_from_db, fetch_one_from_db +from testgen.ui.services.query_cache import ( + get_table_group_minimal, + get_test_definition, + get_test_run_minimal, + get_test_suite, + get_test_suite_minimal, + select_test_definitions_where, +) from testgen.ui.services.string_service import snake_case_to_title_case from testgen.ui.session import session from testgen.utils import friendly_score, make_json_safe @@ -134,7 +140,7 @@ def render( sort: str | None = None, **_kwargs, ) -> None: - run = TestRun.get_minimal(run_id) + run = get_test_run_minimal(run_id) if not run: self.router.navigate_with_warning( f"Test run with ID '{run_id}' does not exist. Redirecting to list of Test Runs ...", @@ -166,7 +172,7 @@ def render( # Handle deferred export/issue report (still use st.dialog for file downloads) export_filters = st.session_state.pop(EXPORT_FILTERS_KEY, None) if export_filters is not None: - test_suite = TestSuite.get_minimal(run.test_suite_id) + test_suite = get_test_suite_minimal(run.test_suite_id) _handle_export(export_filters, run_id, run_date, test_suite) issue_report_data = st.session_state.pop(ISSUE_REPORT_KEY, None) @@ -216,7 +222,7 @@ def render( filter_options = test_result_queries.get_filter_options(run_id) - test_suite = TestSuite.get_minimal(run.test_suite_id) + test_suite = get_test_suite_minimal(run.test_suite_id) items = json.loads(df.to_json(orient="records", date_unit="s")) summary = get_test_result_summary(run_id) @@ -412,7 +418,7 @@ def on_edit_test_saved(test_def: dict) -> None: def on_validate_test(test_def: dict) -> None: from testgen.ui.views.test_definitions import validate_test - table_group = TableGroup.get_minimal(test_suite.table_groups_id) + table_group = get_table_group_minimal(test_suite.table_groups_id) try: validate_test(test_def, table_group) st.session_state[VALIDATE_RESULT_KEY] = {"success": True, "message": "Validation is successful."} @@ -550,12 +556,12 @@ def _build_edit_test_dialog_data(test_definition_id: str | None, test_suite_mini from testgen.ui.views.test_definitions import get_columns, run_test_type_lookup_query - test_def = TestDefinition.select_where(TestDefinition.id == test_definition_id) + test_def = select_test_definitions_where(TestDefinition.id == test_definition_id) if not test_def: return None - full_test_suite = TestSuite.get(test_suite_minimal.id) - table_group = TableGroup.get_minimal(test_suite_minimal.table_groups_id) + full_test_suite = get_test_suite(test_suite_minimal.id) + table_group = get_table_group_minimal(test_suite_minimal.table_groups_id) test_def_row = test_def[0] test_def_dict = {col: getattr(test_def_row, col) for col in TestDefinitionSummary.columns()} for key in ["id", "table_groups_id", "profile_run_id", "test_suite_id"]: @@ -672,7 +678,7 @@ def readable_boolean(v: bool) -> str: if not test_definition_id: return None - test_definition = TestDefinition.get(test_definition_id) + test_definition = get_test_definition(test_definition_id) if not test_definition: return None @@ -730,8 +736,7 @@ def readable_boolean(v: bool) -> str: def _handle_export(export_filters: dict, run_id: str, run_date: str, test_suite: TestSuiteMinimal) -> None: - from testgen.common.models.table_group import TableGroup - table_group = TableGroup.get_minimal(test_suite.table_groups_id) + table_group = get_table_group_minimal(test_suite.table_groups_id) export_type = export_filters.get("type", "all") with st.spinner("Loading data ..."): diff --git a/testgen/ui/views/test_runs.py b/testgen/ui/views/test_runs.py index 2d490069..f6563b94 100644 --- a/testgen/ui/views/test_runs.py +++ b/testgen/ui/views/test_runs.py @@ -21,7 +21,13 @@ from testgen.ui.navigation.menu import MenuItem from testgen.ui.navigation.page import Page from testgen.ui.navigation.router import Router -from testgen.ui.services.query_cache import get_project_summary, get_test_run_summaries +from testgen.ui.services.query_cache import ( + get_project_summary, + get_test_run_summaries, + select_table_groups_minimal_where, + select_test_runs_where, + select_test_suites_minimal_where, +) from testgen.ui.session import session from testgen.ui.views.dialogs.manage_notifications import NotificationSettingsDialogBase from testgen.ui.views.dialogs.manage_schedules import ScheduleDialog @@ -61,8 +67,8 @@ def render(self, project_code: str, table_group_id: str | None = None, test_suit with st.spinner("Loading data ..."): project_summary = get_project_summary(project_code) test_runs, total_count = get_test_run_summaries(project_code, table_group_id, test_suite_id, page=page) - table_groups = TableGroup.select_minimal_where(TableGroup.project_code == project_code) - test_suites = TestSuite.select_minimal_where(TestSuite.project_code == project_code, TestSuite.is_monitor.isnot(True)) + table_groups = select_table_groups_minimal_where(TableGroup.project_code == project_code) + test_suites = select_test_suites_minimal_where(TestSuite.project_code == project_code, TestSuite.is_monitor.isnot(True)) def on_run_tests_clicked(*_) -> None: st.session_state[TR_RUN_TESTS_DIALOG_KEY] = True @@ -240,7 +246,7 @@ def _model_to_item_attrs(self, model: TestRunNotificationSettings) -> dict[str, def _get_component_props(self) -> dict[str, Any]: test_suite_options = [ (str(ts.id), ts.test_suite) - for ts in TestSuite.select_minimal_where( + for ts in select_test_suites_minimal_where( TestSuite.project_code == self.ns_attrs["project_code"], TestSuite.is_monitor.isnot(True), ) @@ -268,7 +274,7 @@ class TestRunScheduleDialog(ScheduleDialog): test_suites: Iterable[TestSuiteMinimal] | None = None def init(self) -> None: - self.test_suites = TestSuite.select_minimal_where( + self.test_suites = select_test_suites_minimal_where( TestSuite.project_code == self.project_code, TestSuite.is_monitor.isnot(True), ) @@ -313,7 +319,7 @@ def on_delete_runs(job_execution_ids: list[str]) -> None: continue if job_exec.status in (JobStatus.PENDING, JobStatus.CLAIMED, JobStatus.RUNNING, JobStatus.CANCEL_REQUESTED): job_exec.request_cancel() - test_run = next(iter(TestRun.select_where(TestRun.job_execution_id == je_id)), None) + test_run = next(iter(select_test_runs_where(TestRun.job_execution_id == je_id)), None) if test_run: TestRun.cascade_delete([str(test_run.id)]) get_current_session().delete(job_exec) diff --git a/testgen/ui/views/test_suites.py b/testgen/ui/views/test_suites.py index c32c1b88..d1797f02 100644 --- a/testgen/ui/views/test_suites.py +++ b/testgen/ui/views/test_suites.py @@ -15,7 +15,14 @@ from testgen.ui.navigation.menu import MenuItem from testgen.ui.navigation.page import Page from testgen.ui.navigation.router import Router -from testgen.ui.services.query_cache import get_project_summary, get_test_suite_summaries +from testgen.ui.services.query_cache import ( + get_project_summary, + get_table_group, + get_test_suite, + get_test_suite_minimal, + get_test_suite_summaries, + select_table_groups_minimal_where, +) from testgen.ui.session import session from testgen.ui.views.dialogs.generate_tests_dialog import ( get_generation_set_choices, @@ -57,7 +64,7 @@ def render(self, project_code: str, table_group_id: str | None = None, test_suit "manage-test-suites", ) - table_groups = TableGroup.select_minimal_where(TableGroup.project_code == project_code) + table_groups = select_table_groups_minimal_where(TableGroup.project_code == project_code) user_can_edit = session.auth.user_has_permission("edit") test_suites = get_test_suite_summaries(project_code, table_group_id, test_suite_name) project_summary = get_project_summary(project_code) @@ -77,7 +84,7 @@ def render(self, project_code: str, table_group_id: str | None = None, test_suit "result": st.session_state.get("ts_form_dialog:result"), } elif edit_ts_id := st.session_state.get(EDIT_DIALOG_KEY): - selected = TestSuite.get(edit_ts_id) + selected = get_test_suite(edit_ts_id) form_dialog = { "open": True, "mode": "edit", @@ -133,7 +140,7 @@ def on_run_notifications_clicked(*_) -> None: generate_tests_data = None if generate_tests_ts_id := st.session_state.get(GENERATE_TESTS_DIALOG_KEY): - generate_ts = TestSuite.get_minimal(generate_tests_ts_id) + generate_ts = get_test_suite_minimal(generate_tests_ts_id) generation_sets = get_generation_set_choices() default_set = "Standard" if "Standard" in generation_sets else (generation_sets[0] if generation_sets else "") test_ct, unlocked_test_ct, unlocked_edits_ct = get_test_suite_refresh_warning(str(generate_ts.id)) @@ -315,7 +322,7 @@ def save_test_suite_form(data: dict) -> None: if mode == "edit": test_suite_id = data.get("test_suite_id") - test_suite = TestSuite.get(test_suite_id) + test_suite = get_test_suite(test_suite_id) test_suite.test_suite_description = data.get("test_suite_description", "") test_suite.severity = data.get("severity") test_suite.export_to_observability = data.get("export_to_observability", False) @@ -329,7 +336,7 @@ def save_test_suite_form(data: dict) -> None: get_test_suite_summaries.clear() st.session_state[PAGE_RESULT_KEY] = {"success": True, "message": "Changes have been saved successfully."} else: - table_group = TableGroup.get(data.get("table_groups_id")) + table_group = get_table_group(data.get("table_groups_id")) test_suite = TestSuite() test_suite.project_code = table_group.project_code test_suite.test_suite = data.get("test_suite") @@ -351,7 +358,7 @@ def save_test_suite_form(data: dict) -> None: @with_database_session def prepare_ts_delete_dialog(test_suite_id: str) -> None: - selected = TestSuite.get_minimal(test_suite_id) + selected = get_test_suite_minimal(test_suite_id) is_in_use = TestSuite.is_in_use([selected.id]) st.session_state["ts_delete_dialog"] = { "open": True, @@ -375,7 +382,7 @@ def execute_ts_delete(test_suite_id: str) -> None: @with_database_session def observability_export_action(test_suite_id: str) -> None: - selected_test_suite = TestSuite.get_minimal(test_suite_id) + selected_test_suite = get_test_suite_minimal(test_suite_id) try: qty_of_exported_events = export_test_results(selected_test_suite.id) st.session_state[PAGE_RESULT_KEY] = {"success": True, "message": f"Export finished: {qty_of_exported_events} events exported."} diff --git a/tests/unit/common/models/test_test_definition.py b/tests/unit/common/models/test_test_definition.py index b7cd9431..f733d1b6 100644 --- a/tests/unit/common/models/test_test_definition.py +++ b/tests/unit/common/models/test_test_definition.py @@ -246,12 +246,6 @@ def test_severity_enum_value_accepted(): # --- select_page --- -@pytest.fixture(autouse=True) -def clear_select_page_cache(): - TestDefinition.select_page.clear() - yield - - def _make_summary_row(table_name: str = "my_table") -> dict: return { "id": uuid4(), diff --git a/tests/unit/common/test_no_streamlit_in_common.py b/tests/unit/common/test_no_streamlit_in_common.py new file mode 100644 index 00000000..4f6c9c78 --- /dev/null +++ b/tests/unit/common/test_no_streamlit_in_common.py @@ -0,0 +1,47 @@ +"""Boundary guard — `testgen/common/` must not import Streamlit or use its cache. + +Streamlit caches in-process even outside its runtime; a `@st.cache_data` decorator +on a shared model method leaks stale results into MCP, API, scheduler, and CLI +processes. Cache decorators belong in the UI layer (`testgen/ui/services/query_cache.py` +or view-local helpers), not in `common/`. + +Exception: ``streamlit_authenticator`` is a separately-packaged dependency unrelated +to this boundary; it's allowed. +""" + +from __future__ import annotations + +import re +from pathlib import Path + +import pytest + +import testgen.common as common_pkg + +COMMON_ROOT = Path(common_pkg.__file__).resolve().parent + +_BANNED_PATTERNS = [ + re.compile(r"^\s*import\s+streamlit\s*(?:as\s+\w+)?\s*(?:#.*)?$"), + re.compile(r"^\s*from\s+streamlit(?:\.|\s)"), + re.compile(r"@st\.cache_(data|resource)\b"), +] + + +def _python_files() -> list[Path]: + return sorted(p for p in COMMON_ROOT.rglob("*.py") if "__pycache__" not in p.parts) + + +@pytest.mark.parametrize("path", _python_files(), ids=lambda p: str(p.relative_to(COMMON_ROOT))) +def test_no_streamlit_or_cache_decorator(path: Path) -> None: + text = path.read_text(encoding="utf-8") + offending: list[tuple[int, str]] = [] + for lineno, line in enumerate(text.splitlines(), start=1): + for pattern in _BANNED_PATTERNS: + if pattern.search(line): + offending.append((lineno, line.rstrip())) + break + assert not offending, ( + f"{path.relative_to(COMMON_ROOT)} imports Streamlit or applies an " + f"@st.cache_* decorator. Caching belongs in the UI layer " + f"(testgen/ui/services/query_cache.py). Offending lines: {offending}" + ) diff --git a/tests/unit/ui/services/__init__.py b/tests/unit/ui/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/ui/services/test_query_cache.py b/tests/unit/ui/services/test_query_cache.py new file mode 100644 index 00000000..aefdcb60 --- /dev/null +++ b/tests/unit/ui/services/test_query_cache.py @@ -0,0 +1,60 @@ +"""Wiring tests for query_cache.py wrappers. + +Verifies that each cached UI wrapper exists, is callable, and exposes ``.clear()`` +for targeted cache invalidation. Does NOT exercise Streamlit cache logic itself. +""" + +from __future__ import annotations + +import pytest + +from testgen.ui.services import query_cache + +# Wrappers that replace cached calls to model methods after TG-1091. +# Names match the per-(entity, method) convention documented in the spec. +EXPECTED_WRAPPERS = [ + # Connection + "get_connection", + "select_connections_where", + "get_connection_minimal", + "select_connections_minimal_where", + # User + "get_user", + "select_users_where", + # TableGroup + "get_table_group", + "get_table_group_minimal", + "select_table_groups_minimal_where", + # TestSuite + "get_test_suite", + "get_test_suite_minimal", + "select_test_suites_minimal_where", + # TestRun + "get_test_run_minimal", + "select_test_runs_where", + # ProfilingRun + "get_profiling_run_minimal", + "select_profiling_runs_where", + "select_profiling_runs_minimal_where", + # TestDefinition + "get_test_definition", + "select_test_definitions_where", + "select_test_definitions_minimal_where", + "select_test_definitions_page", + # Project + "get_project", + "select_projects_where", + # ProjectMembership + "get_project_membership", + "select_project_memberships_where", +] + + +@pytest.mark.parametrize("name", EXPECTED_WRAPPERS) +def test_wrapper_exists_and_is_cached(name: str) -> None: + wrapper = getattr(query_cache, name, None) + assert wrapper is not None, f"Missing wrapper: query_cache.{name}" + assert callable(wrapper), f"Wrapper is not callable: query_cache.{name}" + assert hasattr(wrapper, "clear"), ( + f"Wrapper missing .clear() (cache decorator dropped?): query_cache.{name}" + ) diff --git a/tests/unit/ui/test_project_settings.py b/tests/unit/ui/test_project_settings.py index 89c0eeac..08ab933b 100644 --- a/tests/unit/ui/test_project_settings.py +++ b/tests/unit/ui/test_project_settings.py @@ -81,8 +81,8 @@ def test_update_project_raises_on_duplicate_name(mock_session): ] with ( - patch(f"{MODULE}.Project") as mock_project_cls, + patch(f"{MODULE}.select_projects_where") as mock_select, pytest.raises(ValueError, match="Other Project"), ): - mock_project_cls.select_where.return_value = [MagicMock(project_name="Other Project")] + mock_select.return_value = [MagicMock(project_name="Other Project")] page.update_project("proj", {"name": "Other Project", "use_dq_score_weights": True}) From 3e800d99879fd1d908435539333f5e4be57d9680 Mon Sep 17 00:00:00 2001 From: Ricardo Boni Date: Tue, 26 May 2026 20:00:46 -0400 Subject: [PATCH 38/58] fix(common-models): get_previous returns self in TestRun and ProfilingRun TestRun.get_previous() and ProfilingRun.get_previous() filtered by JobExecution.started_at < self.test_starttime (or self.profiling_starttime). JE.started_at is recorded ~30ms before the TR/PR start columns of the same row, so the current row satisfied its own filter and ORDER BY started_at DESC LIMIT 1 returned it back. Single-arg callers of compare_test_runs and compare_profiling_runs (coming online with the new default in this branch and TG-1068 respectively) therefore got target == baseline. Switch both helpers to compare and order on the same column on both sides (TestRun.test_starttime / ProfilingRun.profiling_starttime). Co-Authored-By: Claude Opus 4.7 (1M context) --- testgen/common/models/profiling_run.py | 4 ++-- testgen/common/models/test_run.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/testgen/common/models/profiling_run.py b/testgen/common/models/profiling_run.py index 20e69a27..a252704c 100644 --- a/testgen/common/models/profiling_run.py +++ b/testgen/common/models/profiling_run.py @@ -419,9 +419,9 @@ def get_previous(self) -> Self | None: .where( ProfilingRun.table_groups_id == self.table_groups_id, JobExecution.status == JobStatus.COMPLETED, - JobExecution.started_at < self.profiling_starttime, + ProfilingRun.profiling_starttime < self.profiling_starttime, ) - .order_by(desc(JobExecution.started_at)) + .order_by(desc(ProfilingRun.profiling_starttime)) .limit(1) ) return get_current_session().scalar(query) diff --git a/testgen/common/models/test_run.py b/testgen/common/models/test_run.py index 27e52a8b..50b20d36 100644 --- a/testgen/common/models/test_run.py +++ b/testgen/common/models/test_run.py @@ -196,9 +196,9 @@ def get_previous(self) -> Self | None: .where( TestRun.test_suite_id == self.test_suite_id, JobExecution.status == JobStatus.COMPLETED, - JobExecution.started_at < self.test_starttime, + TestRun.test_starttime < self.test_starttime, ) - .order_by(desc(JobExecution.started_at)) + .order_by(desc(TestRun.test_starttime)) .limit(1) ) return get_current_session().scalar(query) From fe37c412ab28b83adb9da5c66803733fadba9968 Mon Sep 17 00:00:00 2001 From: Ricardo Boni Date: Tue, 26 May 2026 20:00:54 -0400 Subject: [PATCH 39/58] feat(mcp): single-arg compare_test_runs (TG-1056) Rename get_test_run_diff to compare_test_runs and make baseline optional. Mirrors compare_profiling_runs: target_job_execution_id (required) + baseline_job_execution_id (optional, defaults to the immediately previous completed run on the same suite via TestRun.get_previous()). Reject in-progress/error target or baseline runs the same way compare_profiling_runs does, so partial result sets cannot be mistaken for removed tests. Rewrite the compare_runs prompt to route through the tool directly instead of fetching both runs' results via list_test_results and diffing manually. DiffRow and RunDiff field renames (_a/_b -> _baseline/_target) keep the data layer aligned with the rendered Target/Baseline output. Drive-by: drop the "pass baseline_job_execution_id to compare against" suggestion from the no-baseline error in both compare_test_runs and compare_profiling_runs. When the target is the oldest run on its parent, no earlier run exists; suggesting an action that can't help misleads. Co-Authored-By: Claude Opus 4.7 (1M context) --- testgen/common/models/test_result.py | 62 ++-- testgen/mcp/prompts/workflows.py | 16 +- testgen/mcp/server.py | 4 +- testgen/mcp/tools/profile_history.py | 4 +- testgen/mcp/tools/test_results.py | 136 +++++---- tests/unit/mcp/test_tools_profile_history.py | 2 +- tests/unit/mcp/test_tools_test_results.py | 284 ++++++++++++++----- 7 files changed, 342 insertions(+), 166 deletions(-) diff --git a/testgen/common/models/test_result.py b/testgen/common/models/test_result.py index 90538adf..c1d25e73 100644 --- a/testgen/common/models/test_result.py +++ b/testgen/common/models/test_result.py @@ -80,27 +80,27 @@ def failure_rate(self) -> float: @dataclass class DiffRow: - """One test definition's status across two runs for ``get_test_run_diff``.""" + """One test definition's status across two runs for ``compare_test_runs``.""" test_definition_id: UUID test_type: str test_name_short: str | None table_name: str | None column_names: str | None - status_a: TestResultStatus | None - status_b: TestResultStatus | None - measure_a: str | None - measure_b: str | None - threshold_a: str | None - threshold_b: str | None + status_baseline: TestResultStatus | None + status_target: TestResultStatus | None + measure_baseline: str | None + measure_target: str | None + threshold_baseline: str | None + threshold_target: str | None @dataclass class RunDiff: """Categorized diff between two test runs.""" - total_a: int - total_b: int + total_baseline: int + total_target: int regressions: list[DiffRow] = field(default_factory=list) improvements: list[DiffRow] = field(default_factory=list) persistent_failures: list[DiffRow] = field(default_factory=list) @@ -414,7 +414,7 @@ def failure_trend( ] @classmethod - def diff_with_details(cls, test_run_id_a: UUID, test_run_id_b: UUID) -> RunDiff: + def diff_with_details(cls, baseline_run_id: UUID, target_run_id: UUID) -> RunDiff: """Compare two runs by ``test_definition_id`` and return categorized diff rows.""" def _fetch(run_id: UUID) -> dict[UUID, dict]: @@ -448,41 +448,41 @@ def _fetch(run_id: UUID) -> dict[UUID, dict]: for row in get_current_session().execute(query) } - def _row(tid: UUID, info_a: dict | None, info_b: dict | None) -> DiffRow: - base = info_b or info_a # prefer B for display fields (test_type, table, column names) + def _row(tid: UUID, baseline_info: dict | None, target_info: dict | None) -> DiffRow: + base = target_info or baseline_info # prefer target for display fields (test_type, table, column names) return DiffRow( test_definition_id=tid, test_type=base["test_type"], test_name_short=base["test_name_short"], table_name=base["table_name"], column_names=base["column_names"], - status_a=info_a["status"] if info_a else None, - status_b=info_b["status"] if info_b else None, - measure_a=info_a["measure"] if info_a else None, - measure_b=info_b["measure"] if info_b else None, - threshold_a=info_a["threshold"] if info_a else None, - threshold_b=info_b["threshold"] if info_b else None, + status_baseline=baseline_info["status"] if baseline_info else None, + status_target=target_info["status"] if target_info else None, + measure_baseline=baseline_info["measure"] if baseline_info else None, + measure_target=target_info["measure"] if target_info else None, + threshold_baseline=baseline_info["threshold"] if baseline_info else None, + threshold_target=target_info["threshold"] if target_info else None, ) - results_a = _fetch(test_run_id_a) - results_b = _fetch(test_run_id_b) + baseline_results = _fetch(baseline_run_id) + target_results = _fetch(target_run_id) failing = {TestResultStatus.Failed, TestResultStatus.Warning} - diff = RunDiff(total_a=len(results_a), total_b=len(results_b)) + diff = RunDiff(total_baseline=len(baseline_results), total_target=len(target_results)) - for tid in results_a.keys() & results_b.keys(): - info_a, info_b = results_a[tid], results_b[tid] - row = _row(tid, info_a, info_b) - if info_a["status"] == TestResultStatus.Passed and info_b["status"] in failing: + for tid in baseline_results.keys() & target_results.keys(): + baseline_info, target_info = baseline_results[tid], target_results[tid] + row = _row(tid, baseline_info, target_info) + if baseline_info["status"] == TestResultStatus.Passed and target_info["status"] in failing: diff.regressions.append(row) - elif info_a["status"] in failing and info_b["status"] == TestResultStatus.Passed: + elif baseline_info["status"] in failing and target_info["status"] == TestResultStatus.Passed: diff.improvements.append(row) - elif info_a["status"] in failing and info_b["status"] in failing: + elif baseline_info["status"] in failing and target_info["status"] in failing: diff.persistent_failures.append(row) - for tid in results_b.keys() - results_a.keys(): - diff.new_tests.append(_row(tid, None, results_b[tid])) + for tid in target_results.keys() - baseline_results.keys(): + diff.new_tests.append(_row(tid, None, target_results[tid])) - for tid in results_a.keys() - results_b.keys(): - diff.removed_tests.append(_row(tid, results_a[tid], None)) + for tid in baseline_results.keys() - target_results.keys(): + diff.removed_tests.append(_row(tid, baseline_results[tid], None)) return diff diff --git a/testgen/mcp/prompts/workflows.py b/testgen/mcp/prompts/workflows.py index bf90493c..55d833f6 100644 --- a/testgen/mcp/prompts/workflows.py +++ b/testgen/mcp/prompts/workflows.py @@ -112,7 +112,7 @@ def hygiene_triage(table_group_id: str | None = None) -> str: def compare_runs(test_suite: str | None = None) -> str: - """Compare the two most recent test runs to identify regressions and improvements. + """Compare the most recent test run against the previous run to identify regressions and improvements. Args: test_suite: Optional test suite name to focus the comparison on. @@ -120,16 +120,10 @@ def compare_runs(test_suite: str | None = None) -> str: suite_filter = f" for suite `{test_suite}`" if test_suite else "" return f"""\ -Please compare the two most recent test runs{suite_filter} to identify regressions and improvements: +Please compare the most recent test run{suite_filter} against the previous run to identify regressions and improvements: 1. Call `get_data_inventory()` to understand the project structure. -2. Call `list_test_suites(project_code='...')` to find suites{suite_filter} and their latest runs. -3. For the most recent completed run, call `list_test_results(test_suite_id='...')` to get all results. -4. For the previous run, call `list_test_results(job_execution_id='...')` to get all results. -5. Compare the two runs: - - **Regressions:** Tests that passed before but now fail. - - **Improvements:** Tests that failed before but now pass. - - **Persistent failures:** Tests that failed in both runs. - - **Stable passes:** Tests that passed in both runs. -6. Summarize the trend and highlight any concerning regressions. +2. Call `list_test_suites(project_code='...')` to find suites{suite_filter} and their latest run IDs. +3. Call `compare_test_runs(target_job_execution_id='')` — with only the target supplied, the tool automatically diffs against the previous completed run of the same suite. +4. Summarize the trend and highlight any concerning regressions, improvements, persistent failures, or newly added/removed tests. """ diff --git a/testgen/mcp/server.py b/testgen/mcp/server.py index 36f71445..901c665b 100644 --- a/testgen/mcp/server.py +++ b/testgen/mcp/server.py @@ -202,9 +202,9 @@ def build_mcp_server( validate_custom_test, ) from testgen.mcp.tools.test_results import ( + compare_test_runs, get_failure_summary, get_failure_trend, - get_test_run_diff, list_test_result_history, list_test_results, search_test_results, @@ -247,7 +247,7 @@ def safe_prompt(fn): safe_tool(get_failure_summary) safe_tool(search_test_results) safe_tool(get_failure_trend) - safe_tool(get_test_run_diff) + safe_tool(compare_test_runs) safe_tool(get_test_type) safe_tool(get_source_data) safe_tool(get_source_data_query) diff --git a/testgen/mcp/tools/profile_history.py b/testgen/mcp/tools/profile_history.py index 19614365..17839798 100644 --- a/testgen/mcp/tools/profile_history.py +++ b/testgen/mcp/tools/profile_history.py @@ -310,8 +310,8 @@ def compare_profiling_runs( baseline_run = target_run.get_previous() if baseline_run is None: raise MCPUserError( - f"Target run `{target_job_execution_id}` is the first completed profiling run " - "on its table group — pass `baseline_job_execution_id` to compare against." + f"Target run `{target_job_execution_id}` has no earlier completed " + "profiling run on its table group to compare against." ) else: baseline_run = resolve_profiling_run(baseline_job_execution_id) diff --git a/testgen/mcp/tools/test_results.py b/testgen/mcp/tools/test_results.py index dff4331b..d35d3f0a 100644 --- a/testgen/mcp/tools/test_results.py +++ b/testgen/mcp/tools/test_results.py @@ -1,9 +1,12 @@ from datetime import UTC, datetime, timedelta +from uuid import UUID -from testgen.common.models import with_database_session +from testgen.common.enums import JobStatus +from testgen.common.models import get_current_session, with_database_session +from testgen.common.models.job_execution import JobExecution from testgen.common.models.test_definition import TestType from testgen.common.models.test_result import BucketInterval, TestResult, TestResultStatus -from testgen.common.models.test_run import TestRun +from testgen.common.models.test_run import TestRun, TestRunSummary from testgen.common.models.test_suite import TestSuite from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError from testgen.mcp.permissions import get_project_permissions, mcp_permission @@ -520,66 +523,85 @@ def get_failure_trend( @with_database_session @mcp_permission("view") -def get_test_run_diff(job_execution_id_a: str, job_execution_id_b: str) -> str: +def compare_test_runs( + target_job_execution_id: str, + baseline_job_execution_id: str | None = None, +) -> str: """Compare two test runs and report regressions, improvements, persistent failures, and added/removed tests. + When ``baseline_job_execution_id`` is omitted, the baseline defaults to the immediately + previous completed test run on the same test suite as the target run. + Args: - job_execution_id_a: UUID of the older (baseline) test run, e.g. from ``list_test_runs``. - job_execution_id_b: UUID of the newer test run. + target_job_execution_id: UUID of the newer test run, e.g. from ``list_test_runs``. + baseline_job_execution_id: Optional UUID of the older test run. + When omitted, defaults to the previous completed run on the same test suite. """ - uuid_a = parse_uuid(job_execution_id_a, "job_execution_id_a") - uuid_b = parse_uuid(job_execution_id_b, "job_execution_id_b") - - run_a = TestRun.get_by_id_or_job(uuid_a) - run_b = TestRun.get_by_id_or_job(uuid_b) - - # Permission check first — unify "not found" and "inaccessible" (also covers monitor suites, - # which are hidden from this tool the same way they're hidden from the inventory tools). perms = get_project_permissions() - suite_ids = [r.test_suite_id for r in (run_a, run_b) if r is not None] - suites_by_id: dict = {} - if suite_ids: - suites_by_id = { - s.id: s for s in TestSuite.select_where(TestSuite.id.in_(suite_ids)) - } - - def _accessible(run) -> bool: + + def _resolve_accessible(je_id_str: str, je_uuid: UUID) -> TestRun: + run = TestRun.get_by_id_or_job(je_uuid) if run is None: - return False - suite = suites_by_id.get(run.test_suite_id) - if suite is None or suite.is_monitor: - return False - return perms.has_access(suite.project_code) - - if not _accessible(run_a): - raise MCPResourceNotAccessible("Test run", job_execution_id_a) - if not _accessible(run_b): - raise MCPResourceNotAccessible("Test run", job_execution_id_b) - - # Both runs confirmed accessible — safe to reveal suite IDs in the compatibility message. - if run_a.test_suite_id != run_b.test_suite_id: - raise MCPUserError( - "Both runs must belong to the same test suite to be comparable. " - f"Run A is in suite `{run_a.test_suite_id}`, run B is in suite `{run_b.test_suite_id}`. " - "Use `list_test_runs(test_suite=...)` to pick two runs of the same suite." - ) + raise MCPResourceNotAccessible("Test run", je_id_str) + suite = TestSuite.get_regular(run.test_suite_id) + if suite is None or not perms.has_access(suite.project_code): + raise MCPResourceNotAccessible("Test run", je_id_str) + return run + + def _require_completed(run: TestRun, label: str) -> None: + je = get_current_session().get(JobExecution, run.job_execution_id) + if je.status != JobStatus.COMPLETED: + status_label = TestRunSummary.STATUS_LABEL.get(je.status, je.status) + raise MCPUserError( + f"{label} run is in `{status_label}` state — comparison requires a completed run." + ) - diff = TestResult.diff_with_details(run_a.id, run_b.id) + target_uuid = parse_uuid(target_job_execution_id, "target_job_execution_id") + target_run = _resolve_accessible(target_job_execution_id, target_uuid) + _require_completed(target_run, "Target") + + if baseline_job_execution_id is None: + baseline_run = target_run.get_previous() + if baseline_run is None: + raise MCPUserError( + f"Target run `{target_job_execution_id}` has no earlier completed " + "test run on its test suite to compare against." + ) + else: + baseline_uuid = parse_uuid(baseline_job_execution_id, "baseline_job_execution_id") + baseline_run = _resolve_accessible(baseline_job_execution_id, baseline_uuid) + if baseline_run.test_suite_id != target_run.test_suite_id: + raise MCPUserError( + "Both runs must belong to the same test suite to be comparable. " + f"Target is in suite `{target_run.test_suite_id}`, " + f"baseline is in suite `{baseline_run.test_suite_id}`. " + "Use `list_test_runs(test_suite=...)` to pick two runs of the same suite." + ) + _require_completed(baseline_run, "Baseline") + + diff = TestResult.diff_with_details(baseline_run.id, target_run.id) doc = MdDoc() - doc.heading(1, "Test Run Diff") - doc.field("Test Run A", job_execution_id_a, code=True) - doc.field("Test Run B", job_execution_id_b, code=True) + doc.heading(1, "Test Run Comparison") + doc.table( + ["", "Target", "Baseline"], + [ + ["Test Run", + MdDoc.code(str(target_run.job_execution_id)), + MdDoc.code(str(baseline_run.job_execution_id))], + ["Started", target_run.test_starttime, baseline_run.test_starttime], + ], + ) doc.table( headers=["Category", "Count"], rows=[ - ["Regressions (A passed → B failed/warning)", len(diff.regressions)], - ["Improvements (A failed/warning → B passed)", len(diff.improvements)], + ["Regressions (Baseline passed → Target failed/warning)", len(diff.regressions)], + ["Improvements (Baseline failed/warning → Target passed)", len(diff.improvements)], ["Persistent failures", len(diff.persistent_failures)], - ["New tests (only in B)", len(diff.new_tests)], - ["Removed tests (only in A)", len(diff.removed_tests)], - ["Total in A", diff.total_a], - ["Total in B", diff.total_b], + ["New tests (only in Target)", len(diff.new_tests)], + ["Removed tests (only in Baseline)", len(diff.removed_tests)], + ["Total in Target", diff.total_target], + ["Total in Baseline", diff.total_baseline], ], ) @@ -588,17 +610,21 @@ def _section(title: str, rows: list) -> None: return doc.heading(2, title) doc.table( - headers=["Test Type", "Table", "Column", "A → B", "Measure A", "Measure B", "Threshold A", "Threshold B"], + headers=[ + "Test Type", "Table", "Column", "Baseline → Target", + "Measure Baseline", "Measure Target", "Threshold Baseline", "Threshold Target", + ], rows=[ [ row.test_name_short or row.test_type, row.table_name, row.column_names, - f"{row.status_a.value if row.status_a else '—'} → {row.status_b.value if row.status_b else '—'}", - row.measure_a, - row.measure_b, - row.threshold_a, - row.threshold_b, + f"{row.status_baseline.value if row.status_baseline else '—'} → " + f"{row.status_target.value if row.status_target else '—'}", + row.measure_baseline, + row.measure_target, + row.threshold_baseline, + row.threshold_target, ] for row in rows ], diff --git a/tests/unit/mcp/test_tools_profile_history.py b/tests/unit/mcp/test_tools_profile_history.py index 6bec4fe1..e82b5bf3 100644 --- a/tests/unit/mcp/test_tools_profile_history.py +++ b/tests/unit/mcp/test_tools_profile_history.py @@ -337,7 +337,7 @@ def test_compare_profiling_runs_auto_baseline_first_run(mock_resolve, db_session mock_resolve.return_value = target_run with _patch_session([_je()]): - with pytest.raises(MCPUserError, match="first completed profiling run"): + with pytest.raises(MCPUserError, match="no earlier completed profiling run"): compare_profiling_runs(str(target_run.job_execution_id)) diff --git a/tests/unit/mcp/test_tools_test_results.py b/tests/unit/mcp/test_tools_test_results.py index f89c9d1f..7b3d08e4 100644 --- a/tests/unit/mcp/test_tools_test_results.py +++ b/tests/unit/mcp/test_tools_test_results.py @@ -4,6 +4,7 @@ import pytest +from testgen.common.enums import JobStatus from testgen.common.models.test_result import TestResultStatus from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError from testgen.mcp.permissions import ProjectPermissions @@ -866,33 +867,54 @@ def test_get_failure_trend_exclude_today_shifts_end_date(mock_compute, mock_fail # ---------------------------------------------------------------------- -# get_test_run_diff +# compare_test_runs # ---------------------------------------------------------------------- -def _mock_diff_row(status_a, status_b, **overrides): +def _mock_diff_row(status_baseline, status_target, **overrides): row = MagicMock() row.test_definition_id = uuid4() row.test_type = "Pattern_Match" row.test_name_short = "Pattern Match" row.table_name = "orders" row.column_names = "customer_id" - row.status_a = status_a - row.status_b = status_b - row.measure_a = "5" - row.measure_b = "12" - row.threshold_a = "0" - row.threshold_b = "0" + row.status_baseline = status_baseline + row.status_target = status_target + row.measure_baseline = "5" + row.measure_target = "12" + row.threshold_baseline = "0" + row.threshold_target = "0" for k, v in overrides.items(): setattr(row, k, v) return row +def _mock_run(suite_id, je_id=None): + run = MagicMock(id=uuid4(), test_suite_id=suite_id) + run.job_execution_id = je_id or uuid4() + return run + + +def _je(status=JobStatus.COMPLETED): + """Build a JobExecution mock for ``session.get(JobExecution, ...)`` returns.""" + je = MagicMock() + je.status = status + return je + + +def _patch_test_results_session(jes): + """Patch ``get_current_session`` in test_results so ``session.get(JobExecution, ...)`` + returns the given JEs in order (one per ``_require_completed`` call).""" + session = MagicMock() + session.get.side_effect = jes + return patch("testgen.mcp.tools.test_results.get_current_session", return_value=session) + + @patch("testgen.mcp.tools.test_results.TestSuite") @patch("testgen.mcp.tools.test_results.TestResult") @patch("testgen.mcp.tools.test_results.TestRun") @patch("testgen.mcp.permissions._compute_project_permissions") -def test_get_test_run_diff_happy_path( +def test_compare_test_runs_happy_path( mock_compute, mock_test_run_cls, mock_result, mock_test_suite_cls, db_session_mock, ): mock_compute.return_value = ProjectPermissions( @@ -901,18 +923,21 @@ def test_get_test_run_diff_happy_path( username="test_user", ) suite_id = uuid4() - run_a = MagicMock(id=uuid4(), test_suite_id=suite_id) - run_b = MagicMock(id=uuid4(), test_suite_id=suite_id) - mock_test_run_cls.get_by_id_or_job.side_effect = [run_a, run_b] - mock_test_suite_cls.id = MagicMock() # support .in_(...) on attribute mock - mock_test_suite_cls.select_where.return_value = [MagicMock(id=suite_id, project_code="proj_a", is_monitor=False)] + baseline_run = _mock_run(suite_id) + target_run = _mock_run(suite_id) + # Tool resolves target first, then baseline. + mock_test_run_cls.get_by_id_or_job.side_effect = [target_run, baseline_run] + mock_test_suite_cls.get_regular.return_value = _mock_test_suite(suite_id=suite_id, project_code="proj_a") diff = MagicMock() - diff.total_a = 100 - diff.total_b = 100 + diff.total_baseline = 100 + diff.total_target = 100 diff.regressions = [ _mock_diff_row( - TestResultStatus.Passed, TestResultStatus.Failed, threshold_a="1", threshold_b="3", + TestResultStatus.Passed, + TestResultStatus.Failed, + threshold_baseline="1", + threshold_target="3", ) ] diff.improvements = [] @@ -921,69 +946,154 @@ def test_get_test_run_diff_happy_path( diff.removed_tests = [] mock_result.diff_with_details.return_value = diff - from testgen.mcp.tools.test_results import get_test_run_diff + from testgen.mcp.tools.test_results import compare_test_runs - out = get_test_run_diff(str(uuid4()), str(uuid4())) + with _patch_test_results_session([_je(), _je()]): + out = compare_test_runs(str(uuid4()), str(uuid4())) - assert "Test Run Diff" in out + assert "Test Run Comparison" in out assert "Regressions" in out assert "Pattern Match" in out assert "Passed → Failed" in out - assert "Threshold A" in out and "Threshold B" in out + assert "Threshold Baseline" in out and "Threshold Target" in out assert "| 1 | 3 |" in out # threshold columns populated when thresholds changed + # diff_with_details called with (baseline_run.id, target_run.id) in that order. + mock_result.diff_with_details.assert_called_once_with(baseline_run.id, target_run.id) @patch("testgen.mcp.tools.test_results.TestSuite") +@patch("testgen.mcp.tools.test_results.TestResult") @patch("testgen.mcp.tools.test_results.TestRun") @patch("testgen.mcp.permissions._compute_project_permissions") -def test_get_test_run_diff_run_not_found( +def test_compare_test_runs_single_arg_resolves_previous( + mock_compute, mock_test_run_cls, mock_result, mock_test_suite_cls, db_session_mock, +): + """Only target supplied — baseline is resolved via target_run.get_previous().""" + mock_compute.return_value = ProjectPermissions( + memberships={"proj_a": "role_a"}, + permission="view", + username="test_user", + ) + suite_id = uuid4() + target_run = _mock_run(suite_id) + baseline_run = _mock_run(suite_id) + target_run.get_previous.return_value = baseline_run + mock_test_run_cls.get_by_id_or_job.return_value = target_run + mock_test_suite_cls.get_regular.return_value = _mock_test_suite(suite_id=suite_id, project_code="proj_a") + + diff = MagicMock( + total_baseline=10, total_target=10, + regressions=[], improvements=[], persistent_failures=[], new_tests=[], removed_tests=[], + ) + mock_result.diff_with_details.return_value = diff + + from testgen.mcp.tools.test_results import compare_test_runs + + with _patch_test_results_session([_je()]): + out = compare_test_runs(str(uuid4())) + + target_run.get_previous.assert_called_once_with() + mock_result.diff_with_details.assert_called_once_with(baseline_run.id, target_run.id) + # Rendered Baseline cell shows the resolved JE ID, not an input string. + assert str(baseline_run.job_execution_id) in out + + +@patch("testgen.mcp.tools.test_results.TestSuite") +@patch("testgen.mcp.tools.test_results.TestRun") +@patch("testgen.mcp.permissions._compute_project_permissions") +def test_compare_test_runs_single_arg_no_previous_raises( mock_compute, mock_test_run_cls, mock_test_suite_cls, db_session_mock, ): - """One run missing, other accessible — unified error without leaking which side failed.""" + """Target is the oldest run — get_previous() returns None — clear user-facing error.""" mock_compute.return_value = ProjectPermissions( memberships={"proj_a": "role_a"}, permission="view", username="test_user", ) suite_id = uuid4() - mock_test_run_cls.get_by_id_or_job.side_effect = [None, MagicMock(id=uuid4(), test_suite_id=suite_id)] - mock_test_suite_cls.id = MagicMock() - mock_test_suite_cls.select_where.return_value = [MagicMock(id=suite_id, project_code="proj_a", is_monitor=False)] + target_run = _mock_run(suite_id) + target_run.get_previous.return_value = None + mock_test_run_cls.get_by_id_or_job.return_value = target_run + mock_test_suite_cls.get_regular.return_value = _mock_test_suite(suite_id=suite_id, project_code="proj_a") - from testgen.mcp.tools.test_results import get_test_run_diff + from testgen.mcp.tools.test_results import compare_test_runs + + with _patch_test_results_session([_je()]), pytest.raises(MCPUserError, match="no earlier completed test run"): + compare_test_runs(str(uuid4())) + + +@patch("testgen.mcp.tools.test_results.TestSuite") +@patch("testgen.mcp.tools.test_results.TestRun") +@patch("testgen.mcp.permissions._compute_project_permissions") +def test_compare_test_runs_single_arg_inaccessible_target( + mock_compute, mock_test_run_cls, mock_test_suite_cls, db_session_mock, +): + """Inaccessible target — error raised before get_previous() is consulted.""" + mock_compute.return_value = ProjectPermissions( + memberships={"proj_a": "role_a"}, + permission="view", + username="test_user", + ) + suite_id = uuid4() + target_run = _mock_run(suite_id) + mock_test_run_cls.get_by_id_or_job.return_value = target_run + # Monitor suite or inaccessible project — get_regular returns None either way. + mock_test_suite_cls.get_regular.return_value = None + + from testgen.mcp.tools.test_results import compare_test_runs with pytest.raises(MCPResourceNotAccessible, match="Test run .* not found or not accessible"): - get_test_run_diff(str(uuid4()), str(uuid4())) + compare_test_runs(str(uuid4())) + target_run.get_previous.assert_not_called() @patch("testgen.mcp.tools.test_results.TestSuite") @patch("testgen.mcp.tools.test_results.TestRun") @patch("testgen.mcp.permissions._compute_project_permissions") -def test_get_test_run_diff_rejects_inaccessible_project( +def test_compare_test_runs_run_not_found( mock_compute, mock_test_run_cls, mock_test_suite_cls, db_session_mock, ): - """Runs in an inaccessible project produce the same unified message, not a separate one.""" + """Target not found — unified not-found-or-inaccessible error.""" + mock_compute.return_value = ProjectPermissions( + memberships={"proj_a": "role_a"}, + permission="view", + username="test_user", + ) + mock_test_run_cls.get_by_id_or_job.return_value = None + + from testgen.mcp.tools.test_results import compare_test_runs + + with pytest.raises(MCPResourceNotAccessible, match="Test run .* not found or not accessible"): + compare_test_runs(str(uuid4()), str(uuid4())) + + +@patch("testgen.mcp.tools.test_results.TestSuite") +@patch("testgen.mcp.tools.test_results.TestRun") +@patch("testgen.mcp.permissions._compute_project_permissions") +def test_compare_test_runs_rejects_inaccessible_project( + mock_compute, mock_test_run_cls, mock_test_suite_cls, db_session_mock, +): + """Runs in an inaccessible project produce the unified message.""" mock_compute.return_value = ProjectPermissions( memberships={"proj_a": "role_a"}, permission="view", username="test_user", ) suite_id = uuid4() - run = MagicMock(id=uuid4(), test_suite_id=suite_id) - mock_test_run_cls.get_by_id_or_job.side_effect = [run, run] - mock_test_suite_cls.id = MagicMock() - mock_test_suite_cls.select_where.return_value = [MagicMock(id=suite_id, project_code="proj_forbidden", is_monitor=False)] + run = _mock_run(suite_id) + mock_test_run_cls.get_by_id_or_job.return_value = run + mock_test_suite_cls.get_regular.return_value = _mock_test_suite(suite_id=suite_id, project_code="proj_forbidden") - from testgen.mcp.tools.test_results import get_test_run_diff + from testgen.mcp.tools.test_results import compare_test_runs - with pytest.raises(MCPUserError, match="not found or not accessible"): - get_test_run_diff(str(uuid4()), str(uuid4())) + with pytest.raises(MCPResourceNotAccessible, match="not found or not accessible"): + compare_test_runs(str(uuid4()), str(uuid4())) @patch("testgen.mcp.tools.test_results.TestSuite") @patch("testgen.mcp.tools.test_results.TestRun") @patch("testgen.mcp.permissions._compute_project_permissions") -def test_get_test_run_diff_rejects_different_suites( +def test_compare_test_runs_rejects_different_suites( mock_compute, mock_test_run_cls, mock_test_suite_cls, db_session_mock, ): """Both runs accessible but in different suites → suite-mismatch error.""" @@ -992,51 +1102,97 @@ def test_get_test_run_diff_rejects_different_suites( permission="view", username="test_user", ) - suite_id_a = uuid4() - suite_id_b = uuid4() - run_a = MagicMock(id=uuid4(), test_suite_id=suite_id_a) - run_b = MagicMock(id=uuid4(), test_suite_id=suite_id_b) - mock_test_run_cls.get_by_id_or_job.side_effect = [run_a, run_b] - mock_test_suite_cls.id = MagicMock() - mock_test_suite_cls.select_where.return_value = [ - MagicMock(id=suite_id_a, project_code="proj_a", is_monitor=False), - MagicMock(id=suite_id_b, project_code="proj_a", is_monitor=False), + suite_id_target = uuid4() + suite_id_baseline = uuid4() + target_run = _mock_run(suite_id_target) + baseline_run = _mock_run(suite_id_baseline) + mock_test_run_cls.get_by_id_or_job.side_effect = [target_run, baseline_run] + mock_test_suite_cls.get_regular.side_effect = [ + _mock_test_suite(suite_id=suite_id_target, project_code="proj_a"), + _mock_test_suite(suite_id=suite_id_baseline, project_code="proj_a"), ] - from testgen.mcp.tools.test_results import get_test_run_diff + from testgen.mcp.tools.test_results import compare_test_runs - with pytest.raises(MCPUserError, match="must belong to the same test suite"): - get_test_run_diff(str(uuid4()), str(uuid4())) + with _patch_test_results_session([_je()]), pytest.raises(MCPUserError, match="must belong to the same test suite"): + compare_test_runs(str(uuid4()), str(uuid4())) -def test_get_test_run_diff_invalid_uuid(db_session_mock): - from testgen.mcp.tools.test_results import get_test_run_diff +def test_compare_test_runs_invalid_uuid(db_session_mock): + from testgen.mcp.tools.test_results import compare_test_runs with pytest.raises(MCPUserError, match="not a valid UUID"): - get_test_run_diff("bad-uuid", str(uuid4())) + compare_test_runs("bad-uuid", str(uuid4())) @patch("testgen.mcp.tools.test_results.TestSuite") @patch("testgen.mcp.tools.test_results.TestRun") @patch("testgen.mcp.permissions._compute_project_permissions") -def test_get_test_run_diff_rejects_monitor_suite( +def test_compare_test_runs_rejects_monitor_suite( mock_compute, mock_test_run_cls, mock_test_suite_cls, db_session_mock, ): - """Monitor suites are hidden from this tool, same as inaccessible projects — unified message.""" + """Monitor suites are hidden — TestSuite.get_regular returns None — unified message.""" mock_compute.return_value = ProjectPermissions( memberships={"proj_a": "role_a"}, permission="view", username="test_user", ) suite_id = uuid4() - run = MagicMock(id=uuid4(), test_suite_id=suite_id) - mock_test_run_cls.get_by_id_or_job.side_effect = [run, run] - mock_test_suite_cls.id = MagicMock() - mock_test_suite_cls.select_where.return_value = [ - MagicMock(id=suite_id, project_code="proj_a", is_monitor=True) - ] + run = _mock_run(suite_id) + mock_test_run_cls.get_by_id_or_job.return_value = run + mock_test_suite_cls.get_regular.return_value = None - from testgen.mcp.tools.test_results import get_test_run_diff + from testgen.mcp.tools.test_results import compare_test_runs - with pytest.raises(MCPUserError, match="not found or not accessible"): - get_test_run_diff(str(uuid4()), str(uuid4())) + with pytest.raises(MCPResourceNotAccessible, match="not found or not accessible"): + compare_test_runs(str(uuid4()), str(uuid4())) + + +@patch("testgen.mcp.tools.test_results.TestSuite") +@patch("testgen.mcp.tools.test_results.TestRun") +@patch("testgen.mcp.permissions._compute_project_permissions") +def test_compare_test_runs_rejects_target_not_completed( + mock_compute, mock_test_run_cls, mock_test_suite_cls, db_session_mock, +): + """Target run still Running — comparison rejected before any diff work.""" + mock_compute.return_value = ProjectPermissions( + memberships={"proj_a": "role_a"}, + permission="view", + username="test_user", + ) + suite_id = uuid4() + target_run = _mock_run(suite_id) + mock_test_run_cls.get_by_id_or_job.return_value = target_run + mock_test_suite_cls.get_regular.return_value = _mock_test_suite(suite_id=suite_id, project_code="proj_a") + + from testgen.mcp.tools.test_results import compare_test_runs + + with _patch_test_results_session([_je(status=JobStatus.RUNNING)]), \ + pytest.raises(MCPUserError, match=r"Target run is in `Running` state"): + compare_test_runs(str(uuid4())) + target_run.get_previous.assert_not_called() + + +@patch("testgen.mcp.tools.test_results.TestSuite") +@patch("testgen.mcp.tools.test_results.TestRun") +@patch("testgen.mcp.permissions._compute_project_permissions") +def test_compare_test_runs_rejects_baseline_not_completed( + mock_compute, mock_test_run_cls, mock_test_suite_cls, db_session_mock, +): + """Two-arg path: target completes the check but baseline is in Error state.""" + mock_compute.return_value = ProjectPermissions( + memberships={"proj_a": "role_a"}, + permission="view", + username="test_user", + ) + suite_id = uuid4() + target_run = _mock_run(suite_id) + baseline_run = _mock_run(suite_id) + mock_test_run_cls.get_by_id_or_job.side_effect = [target_run, baseline_run] + mock_test_suite_cls.get_regular.return_value = _mock_test_suite(suite_id=suite_id, project_code="proj_a") + + from testgen.mcp.tools.test_results import compare_test_runs + + with _patch_test_results_session([_je(), _je(status=JobStatus.ERROR)]), \ + pytest.raises(MCPUserError, match=r"Baseline run is in `Error` state"): + compare_test_runs(str(uuid4()), str(uuid4())) From 077c70d82e75368177cc2ac2edd4021981fa99ce Mon Sep 17 00:00:00 2001 From: Diogo Basto Date: Mon, 11 May 2026 17:24:24 +0100 Subject: [PATCH 40/58] feat(retention): add per-project data retention cleanup (TG-1063) --- testgen/commands/exec_job.py | 6 +- testgen/commands/job_registry.py | 49 +++- testgen/commands/run_data_cleanup.py | 121 ++++++++ testgen/commands/run_quick_start.py | 2 +- testgen/common/enums.py | 3 + testgen/common/models/__init__.py | 10 +- testgen/common/models/entity.py | 8 +- testgen/common/models/job_execution.py | 40 ++- testgen/common/models/profiling_run.py | 85 +++++- testgen/common/models/project.py | 4 +- testgen/common/models/scheduler.py | 53 +++- testgen/common/models/scores.py | 140 ++++++++- .../common/models/stg_data_chars_update.py | 32 +++ .../models/stg_functional_table_update.py | 26 ++ .../models/stg_secondary_profile_update.py | 28 ++ .../models/stg_test_definition_update.py | 30 ++ testgen/common/models/test_run.py | 76 ++++- testgen/scheduler/cli_scheduler.py | 6 +- .../030_initialize_new_schema_structure.sql | 15 +- .../040_populate_new_schema_project.sql | 12 + .../dbupgrade/0191_incremental_upgrade.sql | 27 ++ .../standalone/project_settings/index.js | 164 ++++++++++- testgen/ui/views/project_settings.py | 91 +++++- tests/unit/commands/test_exec_job.py | 79 +++--- tests/unit/commands/test_run_data_cleanup.py | 265 ++++++++++++++++++ .../unit/common/models/test_job_execution.py | 105 ++++++- tests/unit/common/models/test_scheduler.py | 137 +++++++++ tests/unit/scheduler/test_scheduler_cli.py | 30 +- tests/unit/scheduler/test_scheduler_poll.py | 14 +- tests/unit/ui/test_project_settings.py | 86 +++++- 30 files changed, 1642 insertions(+), 102 deletions(-) create mode 100644 testgen/commands/run_data_cleanup.py create mode 100644 testgen/common/models/stg_data_chars_update.py create mode 100644 testgen/common/models/stg_functional_table_update.py create mode 100644 testgen/common/models/stg_secondary_profile_update.py create mode 100644 testgen/common/models/stg_test_definition_update.py create mode 100644 testgen/template/dbupgrade/0191_incremental_upgrade.sql create mode 100644 tests/unit/commands/test_run_data_cleanup.py create mode 100644 tests/unit/common/models/test_scheduler.py diff --git a/testgen/commands/exec_job.py b/testgen/commands/exec_job.py index e18de23b..48421b0e 100644 --- a/testgen/commands/exec_job.py +++ b/testgen/commands/exec_job.py @@ -36,8 +36,8 @@ def exec_job(job_execution_id: UUID) -> None: LOG.error("Job execution %s not found", job_execution_id) sys.exit(1) - handler = JOB_DISPATCH.get(job_exec.job_key) - if not handler: + job_config = JOB_DISPATCH.get(job_exec.job_key) + if not job_config: job_exec.mark_interrupted(f"Unknown job key: {job_exec.job_key}") return @@ -49,7 +49,7 @@ def exec_job(job_execution_id: UUID) -> None: with database_session(): job_exec = JobExecution.get(job_execution_id) job_context.set(JobContext(job_id=job_execution_id, source=job_exec.source)) - handler(**job_exec.kwargs) + job_config.handler(**job_exec.kwargs) with database_session(): job_exec = JobExecution.get(job_execution_id) diff --git a/testgen/commands/job_registry.py b/testgen/commands/job_registry.py index e2bdde6c..3fa3dceb 100644 --- a/testgen/commands/job_registry.py +++ b/testgen/commands/job_registry.py @@ -1,7 +1,8 @@ """Wiring between the JobExecution engine and the concrete job handlers. Two registries keyed by `job_key`: - - `JOB_DISPATCH`: maps a job to its handler (`exec_job` resolves this). + - `JOB_DISPATCH`: maps a job to its `JobConfig` (handler + per-job metadata). + `exec_job` and the scheduler resolve this. - `JOB_FINAL_CALLBACKS`: maps a job to post-terminal-transition callbacks (notifications, follow-up job submissions). `run_final_callbacks` iterates. @@ -12,15 +13,17 @@ import logging from collections.abc import Callable +from dataclasses import dataclass from sqlalchemy import select +from testgen.commands.run_data_cleanup import run_data_cleanup from testgen.commands.run_profiling import run_profiling from testgen.commands.run_recalculate_project_scores import run_recalculate_project_scores from testgen.commands.run_score_update import run_score_update from testgen.commands.run_test_execution import run_test_execution from testgen.commands.test_generation import run_test_generation -from testgen.common.enums import JobSource, JobStatus +from testgen.common.enums import JobKey, JobSource, JobStatus from testgen.common.models import database_session from testgen.common.models.job_execution import JobExecution from testgen.common.models.profiling_run import ProfilingRun @@ -33,13 +36,31 @@ FinalCallback = Callable[[JobExecution], None] -JOB_DISPATCH: dict[str, Callable] = { - "run-profile": run_profiling, - "run-tests": run_test_execution, - "run-monitors": run_test_execution, - "run-test-generation": run_test_generation, - "run-score-update": run_score_update, - "recalculate-project-scores": run_recalculate_project_scores, + +@dataclass(frozen=True) +class JobConfig: + """Per-job-key registration metadata. + + `scheduler_source` is the value the scheduler tags `JobExecution.source` + with when it spawns this job key — ``"scheduler"`` for user-facing jobs, + ``"system"`` for system-internal jobs (e.g., retention cleanup). Read + only by the scheduler; direct `JobExecution.submit(source=...)` callers + (UI, follow-up enqueues, CLI) set their own source independently and do + not consult this field. + """ + + handler: Callable + scheduler_source: JobSource = JobSource.scheduler + + +JOB_DISPATCH: dict[JobKey, JobConfig] = { + JobKey.run_profile: JobConfig(handler=run_profiling), + JobKey.run_tests: JobConfig(handler=run_test_execution), + JobKey.run_monitors: JobConfig(handler=run_test_execution), + JobKey.run_test_generation: JobConfig(handler=run_test_generation), + JobKey.run_score_update: JobConfig(handler=run_score_update, scheduler_source=JobSource.system), + JobKey.recalculate_project_scores: JobConfig(handler=run_recalculate_project_scores, scheduler_source=JobSource.system), + JobKey.run_data_cleanup: JobConfig(handler=run_data_cleanup, scheduler_source=JobSource.system), } @@ -92,7 +113,7 @@ def _enqueue_score_update(job_exec: JobExecution) -> None: with database_session(): JobExecution.submit( - job_key="run-score-update", + job_key=JobKey.run_score_update, kwargs={ "parent_job_id": str(job_exec.id), "parent_job_key": job_exec.job_key, @@ -102,8 +123,8 @@ def _enqueue_score_update(job_exec: JobExecution) -> None: ) -JOB_FINAL_CALLBACKS: dict[str, list[FinalCallback]] = { - "run-profile": [_notify_profiling_run, _enqueue_score_update], - "run-tests": [_notify_test_run, _enqueue_score_update], - "run-monitors": [_notify_monitor_run], +JOB_FINAL_CALLBACKS: dict[JobKey, list[FinalCallback]] = { + JobKey.run_profile: [_notify_profiling_run, _enqueue_score_update], + JobKey.run_tests: [_notify_test_run, _enqueue_score_update], + JobKey.run_monitors: [_notify_monitor_run], } diff --git a/testgen/commands/run_data_cleanup.py b/testgen/commands/run_data_cleanup.py new file mode 100644 index 00000000..08e6341c --- /dev/null +++ b/testgen/commands/run_data_cleanup.py @@ -0,0 +1,121 @@ +"""Per-project data retention cleanup. + +Deletes profiling runs, test runs, and their child results older than the +project's retention period, plus aged-out staging, score history, and +job_execution records. + +Always preserves the most recent profiling run per table group and the most +recent test run per test suite (including monitor suites). Profiling is +expensive and tends to run infrequently; downstream features — test +generation, freshness monitor generation, data catalog, and MCP analysis +tools — depend on the most recent profiling result for a table group, so +the project must always retain a baseline regardless of retention period +or run cadence. +""" + +import logging +from datetime import UTC, datetime, timedelta + +from testgen.common.models import database_session +from testgen.common.models.job_execution import JobExecution +from testgen.common.models.profiling_run import ProfilingRun +from testgen.common.models.scores import ScoreDefinitionResultHistoryEntry, ScoreHistoryLatestRun +from testgen.common.models.stg_data_chars_update import StgDataCharsUpdate +from testgen.common.models.stg_functional_table_update import StgFunctionalTableUpdate +from testgen.common.models.stg_secondary_profile_update import StgSecondaryProfileUpdate +from testgen.common.models.stg_test_definition_update import StgTestDefinitionUpdate +from testgen.common.models.test_run import TestRun + +LOG = logging.getLogger("testgen") + +BATCH_SIZE = 1000 + + +def run_data_cleanup(project_code: str, retention_days: int) -> None: + started_at = datetime.now(UTC) + cutoff = started_at - timedelta(days=retention_days) + LOG.info( + "Data retention cleanup started: project=%s retention_days=%d cutoff=%s", + project_code, retention_days, cutoff.isoformat(), + ) + + with database_session(): + protected_profiling_ids = ProfilingRun.find_latest_per_table_group(project_code) + protected_test_run_ids = TestRun.find_latest_per_test_suite(project_code) + # Translate protected run ids → their job_execution_ids so the JE sweep + # can carve them out. Nulls (older runs without a JE) are filtered here. + je_map = { + **ProfilingRun.get_job_execution_ids(list(protected_profiling_ids)), + **TestRun.get_job_execution_ids(list(protected_test_run_ids)), + } + protected_job_execution_ids = {je for je in je_map.values() if je is not None} + + LOG.info( + "Protected latest runs: profiling=%d test=%d job_executions=%d", + len(protected_profiling_ids), len(protected_test_run_ids), len(protected_job_execution_ids), + ) + + # Each delete owns its per-batch transactions internally — committing + # between batches releases locks and bounds WAL growth for large sweeps. + deleted_profiling = ProfilingRun.delete_older_than( + cutoff=cutoff, + project_code=project_code, + protected_ids=protected_profiling_ids, + batch_size=BATCH_SIZE, + ) + + deleted_tests = TestRun.delete_older_than( + cutoff=cutoff, + project_code=project_code, + protected_ids=protected_test_run_ids, + batch_size=BATCH_SIZE, + ) + + deleted_job_executions = JobExecution.delete_older_than( + cutoff=cutoff, + project_code=project_code, + protected_ids=protected_job_execution_ids, + batch_size=BATCH_SIZE, + ) + + # Score history: read protected mapping keys BEFORE deleting from either + # table — we need score_history_latest_runs intact to compute the carve-out + # for score_definition_results_history. + with database_session(): + protected_history_keys = ScoreHistoryLatestRun.find_protected_keys( + protected_profiling_ids=protected_profiling_ids, + protected_test_run_ids=protected_test_run_ids, + ) + + deleted_score_history = ScoreDefinitionResultHistoryEntry.delete_older_than( + cutoff=cutoff, + project_code=project_code, + protected_keys=protected_history_keys, + batch_size=BATCH_SIZE, + ) + + deleted_score_latest = ScoreHistoryLatestRun.delete_older_than( + cutoff=cutoff, + project_code=project_code, + protected_keys=protected_history_keys, + batch_size=BATCH_SIZE, + ) + + # Staging tables: defensive cleanup of orphans left behind by failed jobs. + # No carve-out — these are transient operational rows with no run linkage. + with database_session(): + deleted_stg = ( + StgSecondaryProfileUpdate.delete_older_than(cutoff, project_code) + + StgFunctionalTableUpdate.delete_older_than(cutoff, project_code) + + StgDataCharsUpdate.delete_older_than(cutoff, project_code) + + StgTestDefinitionUpdate.delete_older_than(cutoff, project_code) + ) + + elapsed = (datetime.now(UTC) - started_at).total_seconds() + LOG.info( + "Data retention cleanup complete: project=%s " + "deleted_profiling=%d deleted_tests=%d deleted_job_executions=%d " + "deleted_score_history=%d deleted_score_latest=%d deleted_staging=%d elapsed=%.1fs", + project_code, deleted_profiling, deleted_tests, deleted_job_executions, + deleted_score_history, deleted_score_latest, deleted_stg, elapsed, + ) diff --git a/testgen/commands/run_quick_start.py b/testgen/commands/run_quick_start.py index ae4940ff..2d56f312 100644 --- a/testgen/commands/run_quick_start.py +++ b/testgen/commands/run_quick_start.py @@ -68,7 +68,7 @@ def run_with_job_execution( je_id = je.id job_context.set(JobContext(job_id=je_id, source=source)) - JOB_DISPATCH[job_key](**handler_kwargs, run_date=run_date) + JOB_DISPATCH[job_key].handler(**handler_kwargs, run_date=run_date) with database_session(): je = JobExecution.get(je_id) diff --git a/testgen/common/enums.py b/testgen/common/enums.py index 15368356..3679f709 100644 --- a/testgen/common/enums.py +++ b/testgen/common/enums.py @@ -35,6 +35,9 @@ class JobKey(StrEnum): run_tests = "run-tests" run_monitors = "run-monitors" run_test_generation = "run-test-generation" + run_score_update = "run-score-update" + recalculate_project_scores = "recalculate-project-scores" + run_data_cleanup = "run-data-cleanup" class JobSource(StrEnum): diff --git a/testgen/common/models/__init__.py b/testgen/common/models/__init__.py index 4fe7211f..023f6478 100644 --- a/testgen/common/models/__init__.py +++ b/testgen/common/models/__init__.py @@ -4,7 +4,7 @@ import threading import urllib.parse -from sqlalchemy import create_engine +from sqlalchemy import create_engine, delete from sqlalchemy.orm import DeclarativeBase, sessionmaker from sqlalchemy.orm import Session as SQLAlchemySession @@ -30,6 +30,14 @@ class Base(DeclarativeBase): # Can be removed once all models use Mapped[] annotations. __allow_unmapped__ = True + @classmethod + def delete_where(cls, *clauses) -> int: + """Single-statement DELETE on this model filtered by ``clauses``; + returns the row count. Callers may ignore the return when not needed. + """ + result = get_current_session().execute(delete(cls).where(*clauses)) + return result.rowcount or 0 + Session = sessionmaker( engine, expire_on_commit=False, diff --git a/testgen/common/models/entity.py b/testgen/common/models/entity.py index 263c4f5a..248d0d2c 100644 --- a/testgen/common/models/entity.py +++ b/testgen/common/models/entity.py @@ -3,7 +3,7 @@ from typing import Any, Self from uuid import UUID -from sqlalchemy import delete, func, select +from sqlalchemy import func, select from sqlalchemy.dialects import postgresql from sqlalchemy.orm import InstrumentedAttribute from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList @@ -147,12 +147,6 @@ def _paginate( def has_running_process(cls, ids: list[str]) -> bool: raise NotImplementedError - @classmethod - def delete_where(cls, *clauses) -> None: - query = delete(cls).where(*clauses) - db_session = get_current_session() - db_session.execute(query) - @classmethod def is_in_use(cls, ids: list[str]) -> bool: raise NotImplementedError diff --git a/testgen/common/models/job_execution.py b/testgen/common/models/job_execution.py index 0f2f2660..8d8160c2 100644 --- a/testgen/common/models/job_execution.py +++ b/testgen/common/models/job_execution.py @@ -3,11 +3,11 @@ from typing import Any, ClassVar, Self from uuid import UUID, uuid4 -from sqlalchemy import Column, String, Text, case, func, select, text, update +from sqlalchemy import Column, String, Text, case, delete, func, select, text, update from sqlalchemy.dialects import postgresql from testgen.common.enums import JobKey, JobSource, JobStatus -from testgen.common.models import Base, get_current_session +from testgen.common.models import Base, database_session, get_current_session LOG = logging.getLogger("testgen") @@ -148,6 +148,42 @@ def get(cls, execution_id: UUID) -> Self | None: session = get_current_session() return session.get(cls, execution_id) + @classmethod + def delete_older_than( + cls, + cutoff: datetime, + project_code: str, + protected_ids: set[UUID], + batch_size: int = 1000, + ) -> int: + """Batched delete of terminal-state job executions older than cutoff for + the given project, excluding protected ids. Returns total rows deleted. + + Skips rows in non-terminal states (pending/claimed/running/cancel_requested) — + those represent live work and must not be removed regardless of age. + + Each batch runs in its own transaction (committed before the next batch + is selected), so locks on job_executions are released between batches + and WAL growth stays bounded for large sweeps. + """ + where_clauses = [ + cls.project_code == project_code, + cls.completed_at < cutoff, + cls.status.in_([JobStatus.COMPLETED, JobStatus.ERROR, JobStatus.CANCELED]), + ] + if protected_ids: + where_clauses.append(cls.id.notin_(protected_ids)) + + total = 0 + while True: + with database_session() as session: + ids = session.scalars(select(cls.id).where(*where_clauses).limit(batch_size)).all() + if not ids: + break + session.execute(delete(cls).where(cls.id.in_(ids))) + total += len(ids) + return total + @classmethod def list_for_project( cls, diff --git a/testgen/common/models/profiling_run.py b/testgen/common/models/profiling_run.py index 2881e673..225c9942 100644 --- a/testgen/common/models/profiling_run.py +++ b/testgen/common/models/profiling_run.py @@ -11,7 +11,7 @@ from sqlalchemy.sql.expression import case from testgen.common.enums import Disposition, JobStatus -from testgen.common.models import get_current_session +from testgen.common.models import database_session, get_current_session from testgen.common.models.connection import Connection from testgen.common.models.entity import Entity, EntityMinimal from testgen.common.models.job_execution import JobExecution @@ -389,6 +389,89 @@ def cascade_delete(cls, ids: list[str]) -> None: db_session.execute(text(query), {"profiling_run_ids": tuple(ids)}) cls.delete_where(cls.id.in_(ids)) + @classmethod + def find_latest_per_table_group(cls, project_code: str) -> set[UUID]: + """Return the latest completed profiling run id per table group for the + project. + + Used by data retention to protect at least one run per scope. Profiling + is expensive and runs infrequently; downstream features (test + generation, freshness monitor generation, data catalog, MCP analysis + tools) read the most recent profiling result for a table group, so the + latest usable snapshot must survive even when its run_date is past the + retention cutoff. Failed and in-flight runs are skipped because they + don't expose result data for downstream consumers to read. + """ + rows = get_current_session().scalars( + select(cls.id) + .join(JobExecution, cls.job_execution_id == JobExecution.id) + .where( + cls.project_code == project_code, + JobExecution.status == JobStatus.COMPLETED, + ) + .order_by(cls.table_groups_id, cls.profiling_starttime.desc()) + .distinct(cls.table_groups_id) + ).all() + return set(rows) + + @classmethod + def delete_older_than( + cls, + cutoff: datetime, + project_code: str, + protected_ids: set[UUID], + batch_size: int = 1000, + dry_run: bool = False, + ) -> int: + """Batched delete of profiling runs (with cascading children) older than + cutoff for the given project, excluding protected ids. Returns total + parent rows deleted across all batches — or, with ``dry_run=True``, + the number that would be deleted (for retention preview, no writes). + + In-flight runs (JE in PENDING/CLAIMED/RUNNING/CANCEL_REQUESTED) are + never deleted — they may still be writing data. + + Each batch runs in its own transaction (committed before the next batch + is selected), so locks on profiling_runs / profile_results / etc. are + released between batches and WAL growth stays bounded for large sweeps. + """ + where_clauses = [ + cls.project_code == project_code, + cls.profiling_starttime < cutoff, + JobExecution.status.in_([JobStatus.COMPLETED, JobStatus.ERROR, JobStatus.CANCELED]), + ] + if protected_ids: + where_clauses.append(cls.id.notin_(protected_ids)) + + base_select = select(cls.id).join(JobExecution, cls.job_execution_id == JobExecution.id) + + if dry_run: + return get_current_session().scalar( + select(func.count()).select_from(base_select.where(*where_clauses).subquery()) + ) or 0 + + total = 0 + while True: + with database_session() as session: + ids = session.scalars(base_select.where(*where_clauses).limit(batch_size)).all() + if not ids: + break + cls.cascade_delete([str(i) for i in ids]) + total += len(ids) + return total + + @classmethod + def get_job_execution_ids(cls, profiling_run_ids: list[UUID]) -> dict[UUID, UUID | None]: + """Map profiling_run PKs to their job_execution_ids (batch lookup). + + Mirrors TestRun.get_job_execution_ids. + """ + if not profiling_run_ids: + return {} + query = select(cls.id, cls.job_execution_id).where(cls.id.in_(profiling_run_ids)) + rows = get_current_session().execute(query).all() + return {row.id: row.job_execution_id for row in rows} + def init_progress(self) -> None: self._progress = { "data_chars": {"label": "Refreshing data catalog"}, diff --git a/testgen/common/models/project.py b/testgen/common/models/project.py index 5c54872a..40d04974 100644 --- a/testgen/common/models/project.py +++ b/testgen/common/models/project.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from uuid import UUID, uuid4 -from sqlalchemy import Boolean, Column, String, asc, func, select, text +from sqlalchemy import Boolean, Column, Integer, String, asc, func, select, text from sqlalchemy.dialects import postgresql from testgen.common.models import get_current_session @@ -40,6 +40,8 @@ class Project(Entity): observability_api_url: str = Column(NullIfEmptyString) observability_api_key: str = Column(NullIfEmptyString) use_dq_score_weights: bool = Column(Boolean, default=True) + data_retention_enabled: bool = Column(Boolean, nullable=False, default=True) + data_retention_days: int | None = Column(Integer, default=180) _get_by = "project_code" _default_order_by = (asc(func.lower(project_name)),) diff --git a/testgen/common/models/scheduler.py b/testgen/common/models/scheduler.py index f294cf0a..d1c03657 100644 --- a/testgen/common/models/scheduler.py +++ b/testgen/common/models/scheduler.py @@ -17,6 +17,11 @@ RUN_MONITORS_JOB_KEY = "run-monitors" RUN_PROFILE_JOB_KEY = "run-profile" +DEFAULT_DATA_CLEANUP_CRON = "0 1 * * *" +# Non-UI fallback for retention schedule timezone. UI surfaces should instead +# default to the user's browser timezone (resolved client-side). +DEFAULT_RETENTION_CRON_TZ = "UTC" + SCHEDULABLE_JOB_KEYS: frozenset[JobKey] = frozenset({JobKey.run_profile, JobKey.run_tests}) @@ -129,7 +134,51 @@ def select_active_by_kwargs( else: query = query.where(cls.kwargs[k].astext == str(v)) return list(get_current_session().scalars(query).all()) - + + @classmethod + def upsert_for_retention( + cls, + project_code: str, + retention_days: int, + cron_expr: str, + cron_tz: str, + ) -> Self: + """Create or update the data-retention schedule for a project. + + Idempotent — safe to call on project creation and on every retention + settings save. Uniquely keyed by (project_code, JobKey.run_data_cleanup). + """ + session = get_current_session() + schedule = session.scalars( + select(cls).where(cls.project_code == project_code, cls.key == JobKey.run_data_cleanup) + ).first() + kwargs = {"project_code": project_code, "retention_days": retention_days} + if schedule: + schedule.kwargs = kwargs + schedule.cron_expr = cron_expr + schedule.cron_tz = cron_tz + schedule.active = True + else: + schedule = cls( + project_code=project_code, + key=JobKey.run_data_cleanup, + kwargs=kwargs, + cron_expr=cron_expr, + cron_tz=cron_tz, + active=True, + ) + session.add(schedule) + return schedule + + @classmethod + def delete_for_retention(cls, project_code: str) -> None: + """Remove the data-retention schedule for a project (when retention is + disabled or the project is deleted). + """ + get_current_session().execute( + delete(cls).where(cls.project_code == project_code, cls.key == JobKey.run_data_cleanup) + ) + def get_sample_triggering_timestamps(self, n=3) -> list[datetime]: schedule = Cron(cron_string=self.cron_expr).schedule(timezone_str=self.cron_tz) return [schedule.next() for _ in range(n)] @@ -137,7 +186,7 @@ def get_sample_triggering_timestamps(self, n=3) -> list[datetime]: @property def cron_tz_str(self) -> str: return self.cron_tz.replace("_", " ") - + def save(self) -> None: db_session = get_current_session() db_session.add(self) diff --git a/testgen/common/models/scores.py b/testgen/common/models/scores.py index 596361c3..e2df3df7 100644 --- a/testgen/common/models/scores.py +++ b/testgen/common/models/scores.py @@ -24,15 +24,17 @@ column, delete, func, + or_, select, table, text, + tuple_, ) from sqlalchemy.dialects import postgresql from sqlalchemy.orm import aliased, attributes, joinedload, relationship from testgen.common import read_template_sql_file -from testgen.common.models import Base, get_current_session +from testgen.common.models import Base, database_session, get_current_session from testgen.utils import is_uuid4 SCORE_CATEGORIES = [ @@ -878,6 +880,142 @@ def add_as_cutoff(self): session = get_current_session() session.execute(text(query), params) + @classmethod + def delete_older_than( + cls, + cutoff: datetime, + project_code: str, + protected_keys: set[tuple[UUID, datetime]], + batch_size: int = 1000, + ) -> int: + """Batched delete of score-history entries older than cutoff for the + given project, excluding entries whose (definition_id, last_run_time) + is in protected_keys. Preserves snapshots tied to protected latest + runs so the score-trend chart stays consistent with the run. + + Each batch runs in its own transaction (committed before the next batch + is selected) so locks and WAL growth stay bounded for large sweeps. + """ + project_def_ids = select(ScoreDefinition.id).where( + ScoreDefinition.project_code == project_code + ).scalar_subquery() + + where_clauses = [ + cls.last_run_time < cutoff, + cls.definition_id.in_(project_def_ids), + ] + if protected_keys: + where_clauses.append( + tuple_(cls.definition_id, cls.last_run_time).notin_(list(protected_keys)) + ) + + total = 0 + while True: + with database_session() as session: + keys = session.execute( + select(cls.definition_id, cls.last_run_time) + .where(*where_clauses) + .distinct() + .limit(batch_size) + ).all() + if not keys: + break + result = session.execute( + delete(cls).where( + tuple_(cls.definition_id, cls.last_run_time).in_(list(keys)) + ) + ) + total += result.rowcount or 0 + return total + + +class ScoreHistoryLatestRun(Base): + """Snapshot mapping rows: for a score definition + cutoff time, holds the + latest profiling/test run ids active at that point. Score-trend snapshots + in score_definition_results_history correlate to runs through this table. + + The underlying table has no real primary key — the composite declared here + captures the semantic uniqueness (one row per definition x cutoff x scope). + """ + + __tablename__ = "score_history_latest_runs" + + definition_id: UUID = Column(postgresql.UUID(as_uuid=True), nullable=False, primary_key=True) + score_history_cutoff_time: datetime = Column(DateTime(timezone=False), nullable=False, primary_key=True) + table_groups_id: UUID | None = Column(postgresql.UUID(as_uuid=True), nullable=True, primary_key=True) + last_profiling_run_id: UUID | None = Column(postgresql.UUID(as_uuid=True), nullable=True) + test_suite_id: UUID | None = Column(postgresql.UUID(as_uuid=True), nullable=True, primary_key=True) + last_test_run_id: UUID | None = Column(postgresql.UUID(as_uuid=True), nullable=True) + + @classmethod + def find_protected_keys( + cls, + protected_profiling_ids: set[UUID], + protected_test_run_ids: set[UUID], + ) -> set[tuple[UUID, datetime]]: + """Return (definition_id, score_history_cutoff_time) pairs that map to + any protected profiling or test run. Used to preserve score-trend + snapshots tied to runs that retention is keeping alive. + """ + if not protected_profiling_ids and not protected_test_run_ids: + return set() + clauses = [] + if protected_profiling_ids: + clauses.append(cls.last_profiling_run_id.in_(protected_profiling_ids)) + if protected_test_run_ids: + clauses.append(cls.last_test_run_id.in_(protected_test_run_ids)) + rows = get_current_session().execute( + select(cls.definition_id, cls.score_history_cutoff_time).where(or_(*clauses)).distinct() + ).all() + return {tuple(row) for row in rows} + + @classmethod + def delete_older_than( + cls, + cutoff: datetime, + project_code: str, + protected_keys: set[tuple[UUID, datetime]], + batch_size: int = 1000, + ) -> int: + """Batched delete of mapping rows older than cutoff for the given + project, excluding rows whose (definition_id, cutoff_time) is in + protected_keys. + + Each batch runs in its own transaction (committed before the next batch + is selected) so locks and WAL growth stay bounded for large sweeps. + """ + project_def_ids = select(ScoreDefinition.id).where( + ScoreDefinition.project_code == project_code + ).scalar_subquery() + + where_clauses = [ + cls.score_history_cutoff_time < cutoff, + cls.definition_id.in_(project_def_ids), + ] + if protected_keys: + where_clauses.append( + tuple_(cls.definition_id, cls.score_history_cutoff_time).notin_(list(protected_keys)) + ) + + total = 0 + while True: + with database_session() as session: + keys = session.execute( + select(cls.definition_id, cls.score_history_cutoff_time) + .where(*where_clauses) + .distinct() + .limit(batch_size) + ).all() + if not keys: + break + result = session.execute( + delete(cls).where( + tuple_(cls.definition_id, cls.score_history_cutoff_time).in_(list(keys)) + ) + ) + total += result.rowcount or 0 + return total + class ScoreCard(TypedDict): id: str diff --git a/testgen/common/models/stg_data_chars_update.py b/testgen/common/models/stg_data_chars_update.py new file mode 100644 index 00000000..4da2c6d1 --- /dev/null +++ b/testgen/common/models/stg_data_chars_update.py @@ -0,0 +1,32 @@ +"""ORM model for the stg_data_chars_updates staging table. + +Cleaned per-run by `data_chars_staging_delete.sql`; this model exists for +data retention to age out orphans left by failed/interrupted profiling runs. +Has no project_code column — project scope is enforced via a subquery on +table_groups. PK declared is cosmetic; only WHERE columns are needed for +bulk DELETE. +""" + +from datetime import datetime +from uuid import UUID + +from sqlalchemy import Column, String, select +from sqlalchemy.dialects import postgresql + +from testgen.common.models import Base +from testgen.common.models.table_group import TableGroup + + +class StgDataCharsUpdate(Base): + __tablename__ = "stg_data_chars_updates" + + table_groups_id: UUID = Column(postgresql.UUID(as_uuid=True), primary_key=True, nullable=False) + run_date: datetime = Column(postgresql.TIMESTAMP, primary_key=True, nullable=False) + schema_name: str = Column(String(120), primary_key=True) + table_name: str = Column(String(120), primary_key=True) + column_name: str = Column(String(120), primary_key=True) + + @classmethod + def delete_older_than(cls, cutoff: datetime, project_code: str) -> int: + project_table_groups = select(TableGroup.id).where(TableGroup.project_code == project_code) + return cls.delete_where(cls.run_date < cutoff, cls.table_groups_id.in_(project_table_groups)) diff --git a/testgen/common/models/stg_functional_table_update.py b/testgen/common/models/stg_functional_table_update.py new file mode 100644 index 00000000..783f949f --- /dev/null +++ b/testgen/common/models/stg_functional_table_update.py @@ -0,0 +1,26 @@ +"""ORM model for the stg_functional_table_updates staging table. + +Unlike the other staging tables, this one has no per-run delete anywhere in +the codebase — rows accumulate indefinitely. Data retention is the primary +cleanup. PK declared is cosmetic; only WHERE columns are needed for bulk DELETE. +""" + +from datetime import datetime + +from sqlalchemy import Column, String +from sqlalchemy.dialects import postgresql + +from testgen.common.models import Base + + +class StgFunctionalTableUpdate(Base): + __tablename__ = "stg_functional_table_updates" + + project_code: str = Column(String(30), primary_key=True, nullable=False) + run_date: datetime = Column(postgresql.TIMESTAMP, primary_key=True, nullable=False) + schema_name: str = Column(String(50), primary_key=True) + table_name: str = Column(String(120), primary_key=True) + + @classmethod + def delete_older_than(cls, cutoff: datetime, project_code: str) -> int: + return cls.delete_where(cls.run_date < cutoff, cls.project_code == project_code) diff --git a/testgen/common/models/stg_secondary_profile_update.py b/testgen/common/models/stg_secondary_profile_update.py new file mode 100644 index 00000000..68705362 --- /dev/null +++ b/testgen/common/models/stg_secondary_profile_update.py @@ -0,0 +1,28 @@ +"""ORM model for the stg_secondary_profile_updates staging table. + +Cleaned per-run by `secondary_profiling_delete.sql`; this model exists for +data retention to age out orphans left by failed/interrupted profiling runs. +The PK declared here is cosmetic — only the WHERE columns are needed for the +bulk DELETE. See `staging` package docs in `run_data_cleanup.py` for context. +""" + +from datetime import datetime + +from sqlalchemy import Column, String +from sqlalchemy.dialects import postgresql + +from testgen.common.models import Base + + +class StgSecondaryProfileUpdate(Base): + __tablename__ = "stg_secondary_profile_updates" + + project_code: str = Column(String(30), primary_key=True, nullable=False) + run_date: datetime = Column(postgresql.TIMESTAMP, primary_key=True, nullable=False) + schema_name: str = Column(String(50), primary_key=True) + table_name: str = Column(String(120), primary_key=True) + column_name: str = Column(String(120), primary_key=True) + + @classmethod + def delete_older_than(cls, cutoff: datetime, project_code: str) -> int: + return cls.delete_where(cls.run_date < cutoff, cls.project_code == project_code) diff --git a/testgen/common/models/stg_test_definition_update.py b/testgen/common/models/stg_test_definition_update.py new file mode 100644 index 00000000..4da09f52 --- /dev/null +++ b/testgen/common/models/stg_test_definition_update.py @@ -0,0 +1,30 @@ +"""ORM model for the stg_test_definition_updates staging table. + +Cleaned per-run by `delete_staging_test_definitions.sql`; this model exists +for data retention to age out orphans left by failed/interrupted prediction +runs. Has no project_code column — project scope is enforced via a subquery +on test_suites. PK declared is cosmetic; only WHERE columns are needed for +bulk DELETE. +""" + +from datetime import datetime +from uuid import UUID + +from sqlalchemy import Column, select +from sqlalchemy.dialects import postgresql + +from testgen.common.models import Base +from testgen.common.models.test_suite import TestSuite + + +class StgTestDefinitionUpdate(Base): + __tablename__ = "stg_test_definition_updates" + + test_suite_id: UUID = Column(postgresql.UUID(as_uuid=True), primary_key=True, nullable=False) + test_definition_id: UUID = Column(postgresql.UUID(as_uuid=True), primary_key=True, nullable=False) + run_date: datetime = Column(postgresql.TIMESTAMP, primary_key=True, nullable=False) + + @classmethod + def delete_older_than(cls, cutoff: datetime, project_code: str) -> int: + project_test_suites = select(TestSuite.id).where(TestSuite.project_code == project_code) + return cls.delete_where(cls.run_date < cutoff, cls.test_suite_id.in_(project_test_suites)) diff --git a/testgen/common/models/test_run.py b/testgen/common/models/test_run.py index cb701576..2bb9bc51 100644 --- a/testgen/common/models/test_run.py +++ b/testgen/common/models/test_run.py @@ -9,7 +9,7 @@ from sqlalchemy.sql.expression import case from testgen.common.enums import JobStatus -from testgen.common.models import get_current_session +from testgen.common.models import database_session, get_current_session from testgen.common.models.connection import Connection from testgen.common.models.entity import Entity, EntityMinimal from testgen.common.models.job_execution import JobExecution @@ -404,6 +404,80 @@ def cascade_delete(cls, ids: list[str]) -> None: db_session.execute(text(query), {"test_run_ids": tuple(ids)}) cls.delete_where(cls.id.in_(ids)) + @classmethod + def find_latest_per_test_suite(cls, project_code: str) -> set[UUID]: + """Return the latest completed test run id per test suite for the + project. + + Includes monitor suites (`is_monitor=True`). Used by data retention to + protect at least one run per scope so each suite keeps a usable + baseline when retention sweeps clear older history. Failed and + in-flight runs are skipped. + """ + rows = get_current_session().scalars( + select(cls.id) + .join(TestSuite, cls.test_suite_id == TestSuite.id) + .join(JobExecution, cls.job_execution_id == JobExecution.id) + .where( + TestSuite.project_code == project_code, + JobExecution.status == JobStatus.COMPLETED, + ) + .order_by(cls.test_suite_id, cls.test_starttime.desc()) + .distinct(cls.test_suite_id) + ).all() + return set(rows) + + @classmethod + def delete_older_than( + cls, + cutoff: datetime, + project_code: str, + protected_ids: set[UUID], + batch_size: int = 1000, + dry_run: bool = False, + ) -> int: + """Batched delete of test runs (with cascading children) older than + cutoff for the given project, excluding protected ids. Returns total + parent rows deleted across all batches — or, with ``dry_run=True``, + the number that would be deleted (for retention preview, no writes). + + In-flight runs (JE in PENDING/CLAIMED/RUNNING/CANCEL_REQUESTED) are + never deleted — they may still be writing data. + + Each batch runs in its own transaction (committed before the next batch + is selected), so locks on test_runs / test_results / etc. are released + between batches and WAL growth stays bounded for large sweeps. + """ + where_clauses = [ + TestSuite.project_code == project_code, + cls.test_starttime < cutoff, + JobExecution.status.in_([JobStatus.COMPLETED, JobStatus.ERROR, JobStatus.CANCELED]), + ] + if protected_ids: + where_clauses.append(cls.id.notin_(protected_ids)) + + base_select = ( + select(cls.id) + .join(TestSuite, cls.test_suite_id == TestSuite.id) + .join(JobExecution, cls.job_execution_id == JobExecution.id) + ) + + if dry_run: + return get_current_session().scalar( + select(func.count()).select_from(base_select.where(*where_clauses).subquery()) + ) or 0 + + total = 0 + while True: + with database_session() as session: + ids = session.scalars(base_select.where(*where_clauses).limit(batch_size)).all() + if not ids: + break + cls.cascade_delete([str(i) for i in ids]) + total += len(ids) + return total + + def init_progress(self) -> None: self._progress = { "data_chars": {"label": "Refreshing data catalog"}, diff --git a/testgen/scheduler/cli_scheduler.py b/testgen/scheduler/cli_scheduler.py index cfa18657..76375138 100644 --- a/testgen/scheduler/cli_scheduler.py +++ b/testgen/scheduler/cli_scheduler.py @@ -12,7 +12,7 @@ from testgen import settings from testgen.commands.job_registry import JOB_DISPATCH, run_final_callbacks -from testgen.common.enums import JobSource, JobStatus +from testgen.common.enums import JobStatus from testgen.common.models import database_session, with_database_session from testgen.common.models.job_execution import JobExecution from testgen.common.models.scheduler import JobSchedule @@ -79,7 +79,7 @@ def start_job(self, job: CliJob, triggering_time: datetime) -> None: JobExecution.submit( job_key=job.key, kwargs=job.kwargs, - source=JobSource.scheduler, + source=JOB_DISPATCH[job.key].scheduler_source, project_code=job.project_code, job_schedule_id=job.job_schedule_id, ) @@ -248,5 +248,3 @@ def run_scheduler(): scheduler = CliScheduler() scheduler.run() - - diff --git a/testgen/template/dbsetup/030_initialize_new_schema_structure.sql b/testgen/template/dbsetup/030_initialize_new_schema_structure.sql index 169704bb..75dd441f 100644 --- a/testgen/template/dbsetup/030_initialize_new_schema_structure.sql +++ b/testgen/template/dbsetup/030_initialize_new_schema_structure.sql @@ -50,14 +50,16 @@ CREATE TABLE stg_test_definition_updates ( ); CREATE TABLE projects ( - id UUID DEFAULT gen_random_uuid(), - project_code VARCHAR(30) NOT NULL + id UUID DEFAULT gen_random_uuid(), + project_code VARCHAR(30) NOT NULL CONSTRAINT projects_project_code_pk PRIMARY KEY, - project_name VARCHAR(50), + project_name VARCHAR(50), observability_api_key TEXT, observability_api_url TEXT DEFAULT '', - use_dq_score_weights BOOLEAN DEFAULT TRUE + use_dq_score_weights BOOLEAN DEFAULT TRUE, + data_retention_enabled BOOLEAN NOT NULL DEFAULT TRUE, + data_retention_days INTEGER DEFAULT 180 ); CREATE TABLE connections ( @@ -966,6 +968,9 @@ CREATE INDEX ix_dsl_tg_tcd CREATE INDEX ix_prun_pc_con ON profiling_runs(project_code, connection_id); +CREATE INDEX ix_prun_pc_starttime + ON profiling_runs(project_code, profiling_starttime); + CREATE INDEX ix_prun_tg ON profiling_runs(table_groups_id); @@ -1093,6 +1098,8 @@ CREATE TABLE job_executions ( CREATE INDEX idx_job_executions_poll ON job_executions (status, created_at) WHERE status = 'pending'; CREATE INDEX idx_job_executions_schedule ON job_executions (job_schedule_id); CREATE INDEX idx_job_executions_project ON job_executions (project_code, created_at DESC); +CREATE INDEX idx_job_executions_project_completed + ON job_executions (project_code, completed_at); CREATE TABLE settings ( key VARCHAR(50) NOT NULL PRIMARY KEY, diff --git a/testgen/template/dbsetup/040_populate_new_schema_project.sql b/testgen/template/dbsetup/040_populate_new_schema_project.sql index 36f6a30c..0e3e5b7c 100644 --- a/testgen/template/dbsetup/040_populate_new_schema_project.sql +++ b/testgen/template/dbsetup/040_populate_new_schema_project.sql @@ -7,6 +7,18 @@ SELECT '{PROJECT_CODE}' as project_code, '{OBSERVABILITY_API_KEY}' as observability_api_key, '{OBSERVABILITY_API_URL}' as observability_api_url; +-- Seed the data retention schedule so the default project's cleanup job +-- runs out of the box (matches the column defaults: enabled, 180 days). +INSERT INTO job_schedules + (id, project_code, key, kwargs, cron_expr, cron_tz, active) +SELECT gen_random_uuid(), + '{PROJECT_CODE}', + 'run-data-cleanup', + jsonb_build_object('project_code', '{PROJECT_CODE}', 'retention_days', 180), + '0 1 * * *', + 'UTC', + TRUE; + WITH inserted_user AS ( INSERT INTO auth_users diff --git a/testgen/template/dbupgrade/0191_incremental_upgrade.sql b/testgen/template/dbupgrade/0191_incremental_upgrade.sql new file mode 100644 index 00000000..84f389c7 --- /dev/null +++ b/testgen/template/dbupgrade/0191_incremental_upgrade.sql @@ -0,0 +1,27 @@ +SET SEARCH_PATH TO {SCHEMA_NAME}; + +-- Add data retention settings to projects. +-- Existing projects start disabled (NULL days); new projects default to enabled at 180 days, +-- enforced via ALTER COLUMN SET DEFAULT after the initial backfill. + +ALTER TABLE projects + ADD COLUMN IF NOT EXISTS data_retention_enabled BOOLEAN NOT NULL DEFAULT FALSE; + +ALTER TABLE projects + ALTER COLUMN data_retention_enabled SET DEFAULT TRUE; + +ALTER TABLE projects + ADD COLUMN IF NOT EXISTS data_retention_days INTEGER; + +ALTER TABLE projects + ALTER COLUMN data_retention_days SET DEFAULT 180; + +-- Indexes supporting data retention sweeps. +-- profiling_runs: retention filters by (project_code, profiling_starttime). +CREATE INDEX IF NOT EXISTS ix_prun_pc_starttime + ON profiling_runs(project_code, profiling_starttime); + +-- job_executions: supports retention queries filtering by +-- (project_code, completed_at). +CREATE INDEX IF NOT EXISTS idx_job_executions_project_completed + ON job_executions(project_code, completed_at); diff --git a/testgen/ui/components/frontend/standalone/project_settings/index.js b/testgen/ui/components/frontend/standalone/project_settings/index.js index 447641f8..0d601c61 100644 --- a/testgen/ui/components/frontend/standalone/project_settings/index.js +++ b/testgen/ui/components/frontend/standalone/project_settings/index.js @@ -5,10 +5,14 @@ import van from '/app/static/js/van.min.js'; import { Card } from '/app/static/js/components/card.js'; import { Input } from '/app/static/js/components/input.js'; import { Button } from '/app/static/js/components/button.js'; -import { required } from '/app/static/js/form_validators.js'; +import { numberBetween, required } from '/app/static/js/form_validators.js'; import { Alert } from '/app/static/js/components/alert.js'; import { Checkbox } from '/app/static/js/components/checkbox.js'; -import { createEmitter, getValue, isEqual } from '/app/static/js/utils.js'; +import { CrontabInput } from '/app/static/js/components/crontab_input.js'; +import { Select } from '/app/static/js/components/select.js'; +import { timezones } from '/app/static/js/values.js'; +import { formatTimestamp } from '/app/static/js/display_utils.js'; +import { createEmitter, debounce, getValue, isEqual } from '/app/static/js/utils.js'; const { div, span } = van.tags; @@ -26,29 +30,88 @@ const { div, span } = van.tags; * @property {VanState} observability_api_url * @property {VanState} observability_api_key * @property {VanState} observability_test_results + * @property {VanState} data_retention_enabled + * @property {VanState} data_retention_days + * @property {VanState} retention_cron_expr + * @property {VanState} retention_cron_tz + * @property {VanState} retention_cron_sample + * @property {VanState} retention_last_run + * @property {VanState<{profiling_count: number, test_count: number}?>} retention_preview * * @param {Properties} props */ const ProjectSettings = (props) => { const { emit } = props; + // Persisted values are reactive: after a Save, the props update with the + // newly-stored values and these derives recompute, letting + // `showRetentionConfirmation` settle back to a clean state. + const browserTz = Intl.DateTimeFormat().resolvedOptions().timeZone || 'UTC'; + const persistedRetentionEnabled = van.derive(() => props.data_retention_enabled.val ?? false); + const persistedRetentionDays = van.derive(() => props.data_retention_days.val ?? 180); + const persistedRetentionCron = van.derive(() => props.retention_cron_expr.val ?? '0 1 * * *'); + const persistedRetentionTz = van.derive(() => props.retention_cron_tz.val ?? browserTz); const /** @type Properties */ form = { name: van.state(props.name.rawVal ?? ''), use_dq_score_weights: van.state(props.use_dq_score_weights.rawVal ?? true), observability_api_key: van.state(props.observability_api_key.rawVal ?? ''), observability_api_url: van.state(props.observability_api_url.rawVal ?? ''), + data_retention_enabled: van.state(persistedRetentionEnabled.val), + data_retention_days: van.state(persistedRetentionDays.val), + retention_cron_expr: van.state(persistedRetentionCron.val), + retention_cron_tz: van.state(persistedRetentionTz.val), }; const formValidity = { name: van.state(!!form.name.rawVal), observability_api_key: van.state(true), observability_api_url: van.state(true), + data_retention_days: van.state(Number.isFinite(form.data_retention_days.rawVal)), }; - const saveDisabled = van.derive(() => !formValidity.name.val || !formValidity.observability_api_url.val || !formValidity.observability_api_key.val); + const saveDisabled = van.derive(() => !formValidity.name.val + || !formValidity.observability_api_url.val + || !formValidity.observability_api_key.val + || (form.data_retention_enabled.val && !formValidity.data_retention_days.val)); const testObservabilityDisabled = van.derive(() => form.observability_api_url.val.length <= 0 || form.observability_api_key.val.length <= 0); + const retentionCronEditorValue = van.derive(() => { + if (form.retention_cron_expr.val && form.retention_cron_tz.val && form.data_retention_enabled.val) { + emit('GetCronSample', { + payload: { cron_expr: form.retention_cron_expr.val, tz: form.retention_cron_tz.val }, + }); + } + return { + timezone: form.retention_cron_tz.val, + expression: form.retention_cron_expr.val, + }; + }); + // True when the form would enlarge the next cleanup's delete set — + // turning retention on, or shortening the retention period of a project + // that already has it on. Both cases warrant a delete-preview confirmation + // before saving. + const showRetentionConfirmation = van.derive(() => { + if (!form.data_retention_enabled.val) return false; + if (!persistedRetentionEnabled.val) return true; + return form.data_retention_days.val < persistedRetentionDays.val; + }); + // Debounce so rapid days edits collapse to a single round-trip. + const previewPending = van.state(false); + const emitPreviewRequest = debounce((days) => { + emit('GetRetentionPreview', { payload: { retention_days: days } }); + }, 300); + van.derive(() => { + if (showRetentionConfirmation.val && formValidity.data_retention_days.val) { + previewPending.val = true; + emitPreviewRequest(form.data_retention_days.val); + } + }); + van.derive(() => { + if (getValue(props.retention_preview) !== null && getValue(props.retention_preview) !== undefined) { + previewPending.val = false; + } + }); return div( { class: 'flex-column fx-gap-3' }, div( - { class: 'flex-column fx-gap-1' }, + { class: 'flex-column fx-gap-1', style: 'max-width: 700px;' }, span({ class: 'body m' }, 'Project Info'), Card({ class: 'mb-0', @@ -74,7 +137,7 @@ const ProjectSettings = (props) => { }), ), div( - { class: 'flex-column fx-gap-1' }, + { class: 'flex-column fx-gap-1', style: 'max-width: 700px;' }, span({ class: 'body m' }, 'Observability Integration'), Card({ class: 'mb-0', @@ -129,6 +192,97 @@ const ProjectSettings = (props) => { ), }), ), + div( + { class: 'flex-column fx-gap-1', style: 'max-width: 700px;' }, + span({ class: 'body m' }, 'Data Retention'), + Card({ + class: 'mb-0', + border: true, + content: div( + { class: 'flex-column fx-gap-3' }, + Checkbox({ + label: 'Enable data retention', + checked: form.data_retention_enabled, + help: 'Automatically delete old profiling and test run data to keep your database lean. The most recent run in each suite or table group is always preserved.', + onChange: (checked) => { form.data_retention_enabled.val = checked; }, + }), + () => form.data_retention_enabled.val + ? div( + { class: 'flex-column fx-gap-3' }, + Input({ + label: 'Retention period (days)', + value: form.data_retention_days, + type: 'number', + step: 1, + validators: [ required, numberBetween(30, 9999, 0) ], + onChange: (value, validity) => { + form.data_retention_days.val = value === '' ? NaN : parseInt(value); + formValidity.data_retention_days.val = validity.valid; + }, + }), + () => { + const days = form.data_retention_days.val; + return days >= 30 && days < 60 + ? span( + { class: 'text-caption', style: 'color: var(--purple);' }, + 'Monitors perform better with more historical data — at least two months is recommended.', + ) + : ''; + }, + div( + { class: 'flex-row fx-gap-3 fx-flex-wrap fx-align-flex-start' }, + () => Select({ + label: 'Timezone', + options: timezones.map((tz_) => ({ label: tz_, value: tz_ })), + value: form.retention_cron_tz, + allowNull: false, + filterable: true, + onChange: (value) => { form.retention_cron_tz.val = value; }, + portalClass: 'short-select-portal', + style: 'flex: auto;', + }), + div( + { style: 'flex: auto;' }, + CrontabInput({ + emit, + name: 'data_retention_schedule', + sample: props.retention_cron_sample, + value: retentionCronEditorValue, + modes: ['x_hours', 'x_days'], + hideExpression: true, + onChange: (value) => { form.retention_cron_expr.val = value; }, + }), + ), + ), + () => { + const lastRun = getValue(props.retention_last_run); + const sample = getValue(props.retention_cron_sample) ?? {}; + const nextSample = (sample.samples ?? [])[0]; + return div( + { class: 'flex-column fx-gap-1 text-caption' }, + span(`Last cleanup ran: ${lastRun ? formatTimestamp(lastRun) : 'never'}`), + nextSample ? span(`Next cleanup: ${nextSample}`) : '', + ); + }, + () => { + if (!showRetentionConfirmation.val) return ''; + const preview = getValue(props.retention_preview); + const profilingCt = preview?.profiling_count ?? 0; + const testCt = preview?.test_count ?? 0; + const showing = preview !== null && preview !== undefined && !previewPending.val; + const message = !showing + ? 'Calculating impact…' + : `This will delete approximately ${profilingCt} profiling run${profilingCt === 1 ? '' : 's'} and ${testCt} test run${testCt === 1 ? '' : 's'} at the next cleanup. Deleted data cannot be recovered.`; + return Alert( + { type: 'warn' }, + span(message), + ); + }, + ) + : '', + ), + }), + ), div( { class: 'flex-row fx-justify-content-flex-end' }, Button({ diff --git a/testgen/ui/views/project_settings.py b/testgen/ui/views/project_settings.py index 7cc37900..df188e2a 100644 --- a/testgen/ui/views/project_settings.py +++ b/testgen/ui/views/project_settings.py @@ -1,19 +1,31 @@ import random import typing from dataclasses import asdict, dataclass, field +from datetime import UTC, datetime, timedelta import streamlit as st +from sqlalchemy import select from testgen.commands.run_observability_exporter import test_observability_exporter -from testgen.common.enums import JobSource -from testgen.common.models import with_database_session +from testgen.common.enums import JobKey, JobSource, JobStatus +from testgen.common.models import database_session, with_database_session from testgen.common.models.job_execution import JobExecution +from testgen.common.models.profiling_run import ProfilingRun from testgen.common.models.project import Project +from testgen.common.models.scheduler import ( + DEFAULT_DATA_CLEANUP_CRON, + DEFAULT_RETENTION_CRON_TZ, + JobSchedule, +) +from testgen.common.models.test_run import TestRun from testgen.ui.components import widgets as testgen from testgen.ui.navigation.menu import MenuItem from testgen.ui.navigation.page import Page from testgen.ui.services.query_cache import get_project, select_projects_where from testgen.ui.session import session, temp_value +from testgen.ui.utils import get_cron_sample_handler + +DEFAULT_RETENTION_DAYS = 180 PAGE_TITLE = "Project Settings" @@ -37,6 +49,11 @@ class ProjectSettingsPage(Page): def render(self, project_code: str | None = None, **_kwargs) -> None: self.project = get_project(project_code) + retention_schedule = JobSchedule.get( + JobSchedule.project_code == project_code, + JobSchedule.key == JobKey.run_data_cleanup, + ) + retention_last_run = self._get_last_cleanup_timestamp(project_code) testgen.page_header( PAGE_TITLE, @@ -44,11 +61,21 @@ def render(self, project_code: str | None = None, **_kwargs) -> None: ) get_test_results, set_test_results = temp_value(f"project_settings:{project_code}", default=None) + cron_sample_result, on_cron_sample = get_cron_sample_handler( + f"project_settings:cron_sample:{project_code}", sample_count=2, + ) + # Persistent session_state (not pop-on-read) so rapid days edits don't lose the response. + retention_preview_key = f"project_settings:retention_preview:{project_code}" def on_observability_connection_test(payload: dict) -> None: results = self.test_observability_connection(project_code, payload) set_test_results(asdict(results)) + def on_retention_preview(payload: dict) -> None: + st.session_state[retention_preview_key] = self._get_retention_preview( + project_code, payload.get("retention_days"), + ) + return testgen.project_settings( key="project_settings", data={ @@ -57,11 +84,54 @@ def on_observability_connection_test(payload: dict) -> None: "observability_api_url": self.project.observability_api_url, "observability_api_key": self.project.observability_api_key, "observability_test_results": get_test_results(), + "data_retention_enabled": self.project.data_retention_enabled, + "data_retention_days": self.project.data_retention_days or DEFAULT_RETENTION_DAYS, + "retention_cron_expr": retention_schedule.cron_expr if retention_schedule else DEFAULT_DATA_CLEANUP_CRON, + "retention_cron_tz": retention_schedule.cron_tz if retention_schedule else None, + "retention_cron_sample": cron_sample_result(), + "retention_last_run": int(retention_last_run.timestamp() * 1000) if retention_last_run else None, + "retention_preview": st.session_state.get(retention_preview_key), }, on_TestObservabilityClicked_change=on_observability_connection_test, + on_GetCronSample_change=on_cron_sample, + on_GetRetentionPreview_change=on_retention_preview, on_SaveClicked_change=lambda payload: self.update_project(project_code, payload), ) + @staticmethod + def _get_last_cleanup_timestamp(project_code: str) -> datetime | None: + with database_session() as session_: + return session_.scalar( + select(JobExecution.completed_at) + .where( + JobExecution.project_code == project_code, + JobExecution.job_key == JobKey.run_data_cleanup, + JobExecution.status == JobStatus.COMPLETED, + JobExecution.completed_at.isnot(None), + ) + .order_by(JobExecution.completed_at.desc()) + .limit(1) + ) + + @staticmethod + def _get_retention_preview(project_code: str, retention_days: int | None) -> dict | None: + if not retention_days or retention_days < 1: + return None + cutoff = datetime.now(UTC) - timedelta(days=retention_days) + with database_session(): + protected_profiling_ids = ProfilingRun.find_latest_per_table_group(project_code) + protected_test_ids = TestRun.find_latest_per_test_suite(project_code) + return { + "profiling_count": ProfilingRun.delete_older_than( + cutoff, project_code, protected_profiling_ids, dry_run=True, + ), + "test_count": TestRun.delete_older_than( + cutoff, project_code, protected_test_ids, dry_run=True, + ), + # Tiebreaker: identical counts for different days otherwise deep-equal and suppress the prop update. + "_": random.random(), # noqa: S311 + } + @with_database_session def update_project(self, project_code: str, edited_project: dict) -> None: existing_names = [ @@ -77,11 +147,26 @@ def update_project(self, project_code: str, edited_project: dict) -> None: self.project.use_dq_score_weights = edited_project.get("use_dq_score_weights", True) self.project.observability_api_url = edited_project.get("observability_api_url") self.project.observability_api_key = edited_project.get("observability_api_key") + + retention_enabled = bool(edited_project.get("data_retention_enabled")) + retention_days = edited_project.get("data_retention_days") or DEFAULT_RETENTION_DAYS + self.project.data_retention_enabled = retention_enabled + self.project.data_retention_days = retention_days if retention_enabled else None self.project.save() + if retention_enabled: + JobSchedule.upsert_for_retention( + project_code=project_code, + retention_days=retention_days, + cron_expr=edited_project.get("retention_cron_expr") or DEFAULT_DATA_CLEANUP_CRON, + cron_tz=edited_project.get("retention_cron_tz") or DEFAULT_RETENTION_CRON_TZ, + ) + else: + JobSchedule.delete_for_retention(project_code) + if weights_changed: JobExecution.submit( - job_key="recalculate-project-scores", + job_key=JobKey.recalculate_project_scores, kwargs={"project_code": project_code}, source=JobSource.ui, project_code=project_code, diff --git a/tests/unit/commands/test_exec_job.py b/tests/unit/commands/test_exec_job.py index 21ae898a..56df805f 100644 --- a/tests/unit/commands/test_exec_job.py +++ b/tests/unit/commands/test_exec_job.py @@ -4,7 +4,8 @@ import pytest from testgen.commands.exec_job import exec_job -from testgen.commands.job_registry import JOB_DISPATCH, JOB_FINAL_CALLBACKS, run_final_callbacks +from testgen.commands.job_registry import JOB_DISPATCH, JOB_FINAL_CALLBACKS, JobConfig, run_final_callbacks +from testgen.common.enums import JobKey from testgen.common.models.job_execution import JobExecution pytestmark = pytest.mark.unit @@ -19,7 +20,7 @@ def mock_session(): yield session -def _make_job_exec(job_key="run-tests", status="claimed", **kwargs): +def _make_job_exec(job_key=JobKey.run_tests, status="claimed", **kwargs): job = MagicMock(spec=JobExecution) job.id = uuid4() job.job_key = job_key @@ -31,13 +32,13 @@ def _make_job_exec(job_key="run-tests", status="claimed", **kwargs): def test_exec_job_dispatches_run_tests(mock_session): - job = _make_job_exec(job_key="run-tests") + job = _make_job_exec(job_key=JobKey.run_tests) job.mark_running.return_value = True dispatch_mock = Mock(return_value="ok") with ( patch.object(JobExecution, "get", return_value=job), - patch.dict(JOB_DISPATCH, {"run-tests": dispatch_mock}), + patch.dict(JOB_DISPATCH, {JobKey.run_tests: JobConfig(handler=dispatch_mock)}), ): exec_job(job.id) @@ -47,14 +48,14 @@ def test_exec_job_dispatches_run_tests(mock_session): def test_exec_job_dispatches_run_profile(mock_session): - job = _make_job_exec(job_key="run-profile") + job = _make_job_exec(job_key=JobKey.run_profile) job.kwargs = {"table_group_id": "tg-123"} job.mark_running.return_value = True dispatch_mock = Mock(return_value="ok") with ( patch.object(JobExecution, "get", return_value=job), - patch.dict(JOB_DISPATCH, {"run-profile": dispatch_mock}), + patch.dict(JOB_DISPATCH, {JobKey.run_profile: JobConfig(handler=dispatch_mock)}), ): exec_job(job.id) @@ -63,13 +64,13 @@ def test_exec_job_dispatches_run_profile(mock_session): def test_exec_job_dispatches_run_monitors(mock_session): - job = _make_job_exec(job_key="run-monitors") + job = _make_job_exec(job_key=JobKey.run_monitors) job.mark_running.return_value = True dispatch_mock = Mock(return_value="ok") with ( patch.object(JobExecution, "get", return_value=job), - patch.dict(JOB_DISPATCH, {"run-monitors": dispatch_mock}), + patch.dict(JOB_DISPATCH, {JobKey.run_monitors: JobConfig(handler=dispatch_mock)}), ): exec_job(job.id) @@ -77,14 +78,14 @@ def test_exec_job_dispatches_run_monitors(mock_session): def test_exec_job_dispatches_run_test_generation(mock_session): - job = _make_job_exec(job_key="run-test-generation") + job = _make_job_exec(job_key=JobKey.run_test_generation) job.kwargs = {"test_suite_id": "suite-123", "generation_set": "Standard"} job.mark_running.return_value = True dispatch_mock = Mock(return_value="ok") with ( patch.object(JobExecution, "get", return_value=job), - patch.dict(JOB_DISPATCH, {"run-test-generation": dispatch_mock}), + patch.dict(JOB_DISPATCH, {JobKey.run_test_generation: JobConfig(handler=dispatch_mock)}), ): exec_job(job.id) @@ -103,7 +104,7 @@ def test_exec_job_marks_interrupted_on_unknown_key(mock_session): def test_exec_job_skips_when_mark_running_fails(mock_session): - job = _make_job_exec(job_key="run-tests") + job = _make_job_exec(job_key=JobKey.run_tests) job.mark_running.return_value = False with patch.object(JobExecution, "get", return_value=job): @@ -113,12 +114,12 @@ def test_exec_job_skips_when_mark_running_fails(mock_session): def test_exec_job_marks_interrupted_on_dispatch_error(mock_session): - job = _make_job_exec(job_key="run-tests") + job = _make_job_exec(job_key=JobKey.run_tests) job.mark_running.return_value = True with ( patch.object(JobExecution, "get", return_value=job), - patch.dict(JOB_DISPATCH, {"run-tests": Mock(side_effect=RuntimeError("boom"))}), + patch.dict(JOB_DISPATCH, {JobKey.run_tests: JobConfig(handler=Mock(side_effect=RuntimeError("boom")))}), ): exec_job(job.id) @@ -136,24 +137,24 @@ def test_exec_job_exits_on_missing_record(mock_session): def test_job_dispatch_has_all_job_keys(): - assert "run-profile" in JOB_DISPATCH - assert "run-tests" in JOB_DISPATCH - assert "run-monitors" in JOB_DISPATCH - assert "run-test-generation" in JOB_DISPATCH - assert "run-score-update" in JOB_DISPATCH - assert "recalculate-project-scores" in JOB_DISPATCH + assert JobKey.run_profile in JOB_DISPATCH + assert JobKey.run_tests in JOB_DISPATCH + assert JobKey.run_monitors in JOB_DISPATCH + assert JobKey.run_test_generation in JOB_DISPATCH + assert JobKey.run_score_update in JOB_DISPATCH + assert JobKey.recalculate_project_scores in JOB_DISPATCH def test_exec_job_fires_final_callbacks_on_success(mock_session): - job = _make_job_exec(job_key="run-tests") + job = _make_job_exec(job_key=JobKey.run_tests) job.mark_running.return_value = True job.mark_completed.return_value = True cb1, cb2 = Mock(), Mock() with ( patch.object(JobExecution, "get", return_value=job), - patch.dict(JOB_DISPATCH, {"run-tests": Mock(return_value="ok")}), - patch.dict(JOB_FINAL_CALLBACKS, {"run-tests": [cb1, cb2]}), + patch.dict(JOB_DISPATCH, {JobKey.run_tests: JobConfig(handler=Mock(return_value="ok"))}), + patch.dict(JOB_FINAL_CALLBACKS, {JobKey.run_tests: [cb1, cb2]}), ): exec_job(job.id) @@ -162,7 +163,7 @@ def test_exec_job_fires_final_callbacks_on_success(mock_session): def test_exec_job_runs_callbacks_in_registered_order(mock_session): - job = _make_job_exec(job_key="run-tests") + job = _make_job_exec(job_key=JobKey.run_tests) job.mark_running.return_value = True job.mark_completed.return_value = True order = [] @@ -171,8 +172,8 @@ def test_exec_job_runs_callbacks_in_registered_order(mock_session): with ( patch.object(JobExecution, "get", return_value=job), - patch.dict(JOB_DISPATCH, {"run-tests": Mock(return_value="ok")}), - patch.dict(JOB_FINAL_CALLBACKS, {"run-tests": [cb1, cb2]}), + patch.dict(JOB_DISPATCH, {JobKey.run_tests: JobConfig(handler=Mock(return_value="ok"))}), + patch.dict(JOB_FINAL_CALLBACKS, {JobKey.run_tests: [cb1, cb2]}), ): exec_job(job.id) @@ -180,15 +181,15 @@ def test_exec_job_runs_callbacks_in_registered_order(mock_session): def test_exec_job_skips_callbacks_when_mark_completed_fails(mock_session): - job = _make_job_exec(job_key="run-tests") + job = _make_job_exec(job_key=JobKey.run_tests) job.mark_running.return_value = True job.mark_completed.return_value = False cb = Mock() with ( patch.object(JobExecution, "get", return_value=job), - patch.dict(JOB_DISPATCH, {"run-tests": Mock(return_value="ok")}), - patch.dict(JOB_FINAL_CALLBACKS, {"run-tests": [cb]}), + patch.dict(JOB_DISPATCH, {JobKey.run_tests: JobConfig(handler=Mock(return_value="ok"))}), + patch.dict(JOB_FINAL_CALLBACKS, {JobKey.run_tests: [cb]}), ): exec_job(job.id) @@ -196,15 +197,15 @@ def test_exec_job_skips_callbacks_when_mark_completed_fails(mock_session): def test_exec_job_fires_callbacks_on_interrupted(mock_session): - job = _make_job_exec(job_key="run-tests") + job = _make_job_exec(job_key=JobKey.run_tests) job.mark_running.return_value = True job.mark_interrupted.return_value = True cb = Mock() with ( patch.object(JobExecution, "get", return_value=job), - patch.dict(JOB_DISPATCH, {"run-tests": Mock(side_effect=RuntimeError("boom"))}), - patch.dict(JOB_FINAL_CALLBACKS, {"run-tests": [cb]}), + patch.dict(JOB_DISPATCH, {JobKey.run_tests: JobConfig(handler=Mock(side_effect=RuntimeError("boom")))}), + patch.dict(JOB_FINAL_CALLBACKS, {JobKey.run_tests: [cb]}), ): exec_job(job.id) @@ -212,15 +213,15 @@ def test_exec_job_fires_callbacks_on_interrupted(mock_session): def test_exec_job_skips_callbacks_when_mark_interrupted_fails(mock_session): - job = _make_job_exec(job_key="run-tests") + job = _make_job_exec(job_key=JobKey.run_tests) job.mark_running.return_value = True job.mark_interrupted.return_value = False cb = Mock() with ( patch.object(JobExecution, "get", return_value=job), - patch.dict(JOB_DISPATCH, {"run-tests": Mock(side_effect=RuntimeError("boom"))}), - patch.dict(JOB_FINAL_CALLBACKS, {"run-tests": [cb]}), + patch.dict(JOB_DISPATCH, {JobKey.run_tests: JobConfig(handler=Mock(side_effect=RuntimeError("boom")))}), + patch.dict(JOB_FINAL_CALLBACKS, {JobKey.run_tests: [cb]}), ): exec_job(job.id) @@ -228,11 +229,11 @@ def test_exec_job_skips_callbacks_when_mark_interrupted_fails(mock_session): def test_run_final_callbacks_isolates_failures(): - job = _make_job_exec(job_key="run-tests") + job = _make_job_exec(job_key=JobKey.run_tests) failing = Mock(side_effect=RuntimeError("boom"), __name__="failing_cb") succeeding = Mock(__name__="succeeding_cb") - with patch.dict(JOB_FINAL_CALLBACKS, {"run-tests": [failing, succeeding]}): + with patch.dict(JOB_FINAL_CALLBACKS, {JobKey.run_tests: [failing, succeeding]}): run_final_callbacks(job) failing.assert_called_once_with(job) @@ -247,6 +248,6 @@ def test_run_final_callbacks_noop_for_unknown_job_key(): def test_registered_callbacks_cover_notification_job_keys(): - assert "run-profile" in JOB_FINAL_CALLBACKS - assert "run-tests" in JOB_FINAL_CALLBACKS - assert "run-monitors" in JOB_FINAL_CALLBACKS + assert JobKey.run_profile in JOB_FINAL_CALLBACKS + assert JobKey.run_tests in JOB_FINAL_CALLBACKS + assert JobKey.run_monitors in JOB_FINAL_CALLBACKS diff --git a/tests/unit/commands/test_run_data_cleanup.py b/tests/unit/commands/test_run_data_cleanup.py new file mode 100644 index 00000000..6c6e4304 --- /dev/null +++ b/tests/unit/commands/test_run_data_cleanup.py @@ -0,0 +1,265 @@ +from datetime import UTC, datetime, timedelta +from unittest.mock import MagicMock, Mock, patch +from uuid import uuid4 + +import pytest + +from testgen.commands.run_data_cleanup import BATCH_SIZE, run_data_cleanup + +pytestmark = pytest.mark.unit + +MODULE = "testgen.commands.run_data_cleanup" + + +def _db_ctx(): + """Mock database_session() that yields nothing useful — the orchestrator's + nested with-blocks just need the context manager to enter/exit cleanly.""" + ctx = MagicMock() + ctx.__enter__ = Mock(return_value=MagicMock()) + ctx.__exit__ = Mock(return_value=False) + return ctx + + +def _patch_orchestrator( + protected_profiling: set | None = None, + protected_tests: set | None = None, + protected_profiling_jes: set | None = None, + protected_test_jes: set | None = None, + protected_history_keys: set | None = None, + deleted_profiling: int = 0, + deleted_tests: int = 0, + deleted_job_executions: int = 0, + deleted_score_history: int = 0, + deleted_score_latest: int = 0, + deleted_stg: tuple[int, int, int, int] = (0, 0, 0, 0), +): + """One-stop helper: patches every collaborator the orchestrator touches. + + Returns a dict of the patch mocks so individual tests can assert call shape. + """ + patches = { + "database_session": patch(f"{MODULE}.database_session", side_effect=lambda: _db_ctx()), + "ProfilingRun": patch(f"{MODULE}.ProfilingRun"), + "TestRun": patch(f"{MODULE}.TestRun"), + "JobExecution": patch(f"{MODULE}.JobExecution"), + "ScoreHistoryLatestRun": patch(f"{MODULE}.ScoreHistoryLatestRun"), + "ScoreDefinitionResultHistoryEntry": patch(f"{MODULE}.ScoreDefinitionResultHistoryEntry"), + "StgSecondaryProfileUpdate": patch(f"{MODULE}.StgSecondaryProfileUpdate"), + "StgFunctionalTableUpdate": patch(f"{MODULE}.StgFunctionalTableUpdate"), + "StgDataCharsUpdate": patch(f"{MODULE}.StgDataCharsUpdate"), + "StgTestDefinitionUpdate": patch(f"{MODULE}.StgTestDefinitionUpdate"), + } + started = {name: p.start() for name, p in patches.items()} + + started["ProfilingRun"].find_latest_per_table_group.return_value = protected_profiling or set() + # get_job_execution_ids returns dict[run_id, je_id]; orchestrator filters nulls. + started["ProfilingRun"].get_job_execution_ids.return_value = { + uuid4(): je_id for je_id in (protected_profiling_jes or set()) + } + started["ProfilingRun"].delete_older_than.return_value = deleted_profiling + + started["TestRun"].find_latest_per_test_suite.return_value = protected_tests or set() + started["TestRun"].get_job_execution_ids.return_value = { + uuid4(): je_id for je_id in (protected_test_jes or set()) + } + started["TestRun"].delete_older_than.return_value = deleted_tests + + started["JobExecution"].delete_older_than.return_value = deleted_job_executions + + started["ScoreHistoryLatestRun"].find_protected_keys.return_value = protected_history_keys or set() + started["ScoreHistoryLatestRun"].delete_older_than.return_value = deleted_score_latest + started["ScoreDefinitionResultHistoryEntry"].delete_older_than.return_value = deleted_score_history + + started["StgSecondaryProfileUpdate"].delete_older_than.return_value = deleted_stg[0] + started["StgFunctionalTableUpdate"].delete_older_than.return_value = deleted_stg[1] + started["StgDataCharsUpdate"].delete_older_than.return_value = deleted_stg[2] + started["StgTestDefinitionUpdate"].delete_older_than.return_value = deleted_stg[3] + + return started, patches + + +def _stop(patches): + for p in patches.values(): + p.stop() + + +def test_computes_cutoff_from_retention_days(): + """Cutoff passed to delete_older_than is `now - retention_days` (UTC).""" + started, patches = _patch_orchestrator() + try: + before = datetime.now(UTC) + run_data_cleanup(project_code="proj", retention_days=30) + after = datetime.now(UTC) + finally: + _stop(patches) + + cutoff = started["ProfilingRun"].delete_older_than.call_args.kwargs["cutoff"] + expected_low = before - timedelta(days=30) + expected_high = after - timedelta(days=30) + assert expected_low <= cutoff <= expected_high + # Same cutoff threads through every sweep + assert started["TestRun"].delete_older_than.call_args.kwargs["cutoff"] == cutoff + assert started["JobExecution"].delete_older_than.call_args.kwargs["cutoff"] == cutoff + + +def test_passes_protected_profiling_ids_to_delete(): + """Latest-run-per-table-group set is computed once and threaded through to + ProfilingRun.delete_older_than as the carve-out.""" + protected = {uuid4(), uuid4(), uuid4()} + started, patches = _patch_orchestrator(protected_profiling=protected) + try: + run_data_cleanup(project_code="proj", retention_days=180) + finally: + _stop(patches) + + started["ProfilingRun"].find_latest_per_table_group.assert_called_once_with("proj") + assert started["ProfilingRun"].delete_older_than.call_args.kwargs["protected_ids"] == protected + + +def test_passes_protected_test_run_ids_to_delete(): + """Latest-run-per-test-suite (incl. monitor suites) threads through to TestRun.delete_older_than.""" + protected = {uuid4(), uuid4()} + started, patches = _patch_orchestrator(protected_tests=protected) + try: + run_data_cleanup(project_code="proj", retention_days=180) + finally: + _stop(patches) + + started["TestRun"].find_latest_per_test_suite.assert_called_once_with("proj") + assert started["TestRun"].delete_older_than.call_args.kwargs["protected_ids"] == protected + + +def test_protected_job_execution_ids_is_union_of_run_je_ids(): + """JobExecution sweep carve-out = union of protected profiling + test run JE ids.""" + profiling_jes = {uuid4(), uuid4()} + test_jes = {uuid4()} + started, patches = _patch_orchestrator( + protected_profiling_jes=profiling_jes, + protected_test_jes=test_jes, + ) + try: + run_data_cleanup(project_code="proj", retention_days=180) + finally: + _stop(patches) + + passed = started["JobExecution"].delete_older_than.call_args.kwargs["protected_ids"] + assert passed == profiling_jes | test_jes + + +def test_score_history_uses_protected_keys_from_latest_runs(): + """find_protected_keys runs once with both run-id sets, and its result feeds + BOTH score-history sweeps (history entries + latest-runs mapping).""" + keys = {(uuid4(), datetime(2026, 1, 1)), (uuid4(), datetime(2026, 2, 1))} + profiling_ids = {uuid4()} + test_ids = {uuid4()} + started, patches = _patch_orchestrator( + protected_profiling=profiling_ids, + protected_tests=test_ids, + protected_history_keys=keys, + ) + try: + run_data_cleanup(project_code="proj", retention_days=180) + finally: + _stop(patches) + + started["ScoreHistoryLatestRun"].find_protected_keys.assert_called_once_with( + protected_profiling_ids=profiling_ids, + protected_test_run_ids=test_ids, + ) + assert started["ScoreDefinitionResultHistoryEntry"].delete_older_than.call_args.kwargs["protected_keys"] == keys + assert started["ScoreHistoryLatestRun"].delete_older_than.call_args.kwargs["protected_keys"] == keys + + +def test_staging_sweeps_get_no_carve_out(): + """All 4 staging models receive only cutoff + project_code — no protected_ids + arg (these tables have no per-run linkage).""" + started, patches = _patch_orchestrator() + try: + run_data_cleanup(project_code="proj", retention_days=180) + finally: + _stop(patches) + + for stg_name in [ + "StgSecondaryProfileUpdate", + "StgFunctionalTableUpdate", + "StgDataCharsUpdate", + "StgTestDefinitionUpdate", + ]: + call = started[stg_name].delete_older_than.call_args + # Positional args only: (cutoff, project_code) + assert len(call.args) == 2 + assert call.args[1] == "proj" + assert "protected_ids" not in call.kwargs + assert "protected_keys" not in call.kwargs + + +def test_batch_size_threaded_through(): + """The orchestrator's BATCH_SIZE constant is passed to every batch-capable sweep.""" + started, patches = _patch_orchestrator() + try: + run_data_cleanup(project_code="proj", retention_days=180) + finally: + _stop(patches) + + for collaborator, method in [ + ("ProfilingRun", "delete_older_than"), + ("TestRun", "delete_older_than"), + ("JobExecution", "delete_older_than"), + ("ScoreDefinitionResultHistoryEntry", "delete_older_than"), + ("ScoreHistoryLatestRun", "delete_older_than"), + ]: + kwargs = getattr(started[collaborator], method).call_args.kwargs + assert kwargs["batch_size"] == BATCH_SIZE, f"{collaborator}.{method} missing batch_size" + + +def test_summary_log_has_all_counts(caplog): + """The trailing summary log line includes the count from every sweep so the + operator can correlate what was deleted in a single grep.""" + import logging + caplog.set_level(logging.INFO, logger="testgen") + + started, patches = _patch_orchestrator( + deleted_profiling=10, + deleted_tests=20, + deleted_job_executions=30, + deleted_score_history=40, + deleted_score_latest=50, + deleted_stg=(1, 2, 3, 4), # sums to 10 + ) + try: + run_data_cleanup(project_code="proj", retention_days=180) + finally: + _stop(patches) + + summary = [r for r in caplog.records if "Data retention cleanup complete" in r.getMessage()] + assert len(summary) == 1 + msg = summary[0].getMessage() + assert "deleted_profiling=10" in msg + assert "deleted_tests=20" in msg + assert "deleted_job_executions=30" in msg + assert "deleted_score_history=40" in msg + assert "deleted_score_latest=50" in msg + assert "deleted_staging=10" in msg # sum of staging counts + + +def test_no_data_to_delete_runs_clean(): + """Empty everywhere: handler completes without error, all sweeps still invoked.""" + started, patches = _patch_orchestrator() + try: + run_data_cleanup(project_code="proj", retention_days=180) + finally: + _stop(patches) + + # Every sweep was still called (cleanup is unconditional once the schedule fires) + started["ProfilingRun"].delete_older_than.assert_called_once() + started["TestRun"].delete_older_than.assert_called_once() + started["JobExecution"].delete_older_than.assert_called_once() + started["ScoreDefinitionResultHistoryEntry"].delete_older_than.assert_called_once() + started["ScoreHistoryLatestRun"].delete_older_than.assert_called_once() + for stg in [ + "StgSecondaryProfileUpdate", + "StgFunctionalTableUpdate", + "StgDataCharsUpdate", + "StgTestDefinitionUpdate", + ]: + started[stg].delete_older_than.assert_called_once() diff --git a/tests/unit/common/models/test_job_execution.py b/tests/unit/common/models/test_job_execution.py index 2ff28c18..7ee152b7 100644 --- a/tests/unit/common/models/test_job_execution.py +++ b/tests/unit/common/models/test_job_execution.py @@ -1,9 +1,10 @@ +from datetime import UTC, datetime, timedelta from unittest.mock import MagicMock, Mock, patch from uuid import uuid4 import pytest -from testgen.common.models.job_execution import JobExecution +from testgen.common.models.job_execution import JobExecution, JobStatus pytestmark = pytest.mark.unit @@ -21,7 +22,13 @@ def _returning_row(job, **overrides): @pytest.fixture def mock_session(): session = MagicMock() - with patch(f"{MODULE}.get_current_session", return_value=session): + ctx = MagicMock() + ctx.__enter__ = Mock(return_value=session) + ctx.__exit__ = Mock(return_value=False) + with ( + patch(f"{MODULE}.get_current_session", return_value=session), + patch(f"{MODULE}.database_session", return_value=ctx), + ): yield session @@ -185,3 +192,97 @@ def test_request_cancel_terminal_state_returns_false(mock_session): assert job.request_cancel() is False assert job.status == "completed" + + +# ─── delete_older_than (data retention) ───────────────────────────── + + +def _capture_clauses_used_in_select(mock_session): + """Returns the WHERE clauses passed to the candidate-id select query. + + The cleanup loop does select(id).where(*clauses).limit(...). We capture + those clauses to assert which filters were applied.""" + select_call = mock_session.scalars.call_args + select_stmt = select_call.args[0] + return list(select_stmt.whereclause.clauses) if select_stmt.whereclause is not None else [] + + +def test_delete_older_than_filters_only_terminal_statuses(mock_session): + """The status filter is `IN ('completed', 'error', 'canceled')` — non-terminal + rows (pending/claimed/running/cancel_requested) are skipped regardless of age. + This is the key safety guarantee: live work must never be deleted.""" + mock_session.scalars.return_value.all.return_value = [] # no candidates → loop exits + + cutoff = datetime.now(UTC) - timedelta(days=180) + JobExecution.delete_older_than(cutoff=cutoff, project_code="proj", protected_ids=set()) + + clauses = _capture_clauses_used_in_select(mock_session) + status_clause = next( + (c for c in clauses if "status" in str(c).lower()), + None, + ) + assert status_clause is not None + rendered = str(status_clause.compile(compile_kwargs={"literal_binds": True})) + # Must include all three terminal states + for state in (JobStatus.COMPLETED.value, JobStatus.ERROR.value, JobStatus.CANCELED.value): + assert state in rendered + # Must not include any non-terminal state + for state in (JobStatus.PENDING.value, JobStatus.CLAIMED.value, + JobStatus.RUNNING.value, JobStatus.CANCEL_REQUESTED.value): + assert state not in rendered + + +def test_delete_older_than_returns_zero_when_no_candidates(mock_session): + """No-op when nothing is old enough to delete — returns 0, no DELETE executed.""" + mock_session.scalars.return_value.all.return_value = [] + + cutoff = datetime.now(UTC) - timedelta(days=180) + result = JobExecution.delete_older_than(cutoff=cutoff, project_code="proj", protected_ids=set()) + + assert result == 0 + # Only the candidate-select ran; no DELETE statement was issued. + mock_session.execute.assert_not_called() + + +def test_delete_older_than_batches_and_deletes(mock_session): + """Two-batch path: scalars returns one batch, then empty. Both should result + in a DELETE on the first batch, and the total count returned.""" + first_batch = [uuid4(), uuid4(), uuid4()] + mock_session.scalars.return_value.all.side_effect = [first_batch, []] + + cutoff = datetime.now(UTC) - timedelta(days=180) + result = JobExecution.delete_older_than( + cutoff=cutoff, project_code="proj", protected_ids=set(), batch_size=1000, + ) + + assert result == 3 + mock_session.execute.assert_called_once() # one DELETE for one non-empty batch + + +def test_delete_older_than_applies_protected_ids_exclusion(mock_session): + """The protected_ids carve-out — job_executions of protected runs — adds a + NOT IN clause so they survive even when older than the cutoff.""" + protected = {uuid4(), uuid4()} + mock_session.scalars.return_value.all.return_value = [] + + cutoff = datetime.now(UTC) - timedelta(days=180) + JobExecution.delete_older_than(cutoff=cutoff, project_code="proj", protected_ids=protected) + + clauses = _capture_clauses_used_in_select(mock_session) + rendered = " ".join(str(c) for c in clauses).lower() + assert "not in" in rendered or "!= all" in rendered or "in (" in rendered # NOT IN expression present + + +def test_delete_older_than_skips_protected_filter_when_empty(mock_session): + """Empty protected_ids → no NOT IN clause emitted, avoiding the SQL warning + that `IN ()` triggers in postgres.""" + mock_session.scalars.return_value.all.return_value = [] + + cutoff = datetime.now(UTC) - timedelta(days=180) + JobExecution.delete_older_than(cutoff=cutoff, project_code="proj", protected_ids=set()) + + clauses = _capture_clauses_used_in_select(mock_session) + rendered = " ".join(str(c) for c in clauses).lower() + # Three expected clauses: project_code, completed_at, status IN + # Absence of "not in" confirms the protected-ids clause was skipped. + assert "not in" not in rendered diff --git a/tests/unit/common/models/test_scheduler.py b/tests/unit/common/models/test_scheduler.py new file mode 100644 index 00000000..f06f1292 --- /dev/null +++ b/tests/unit/common/models/test_scheduler.py @@ -0,0 +1,137 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from testgen.common.enums import JobKey +from testgen.common.models.scheduler import ( + DEFAULT_DATA_CLEANUP_CRON, + JobSchedule, +) + +pytestmark = pytest.mark.unit + +MODULE = "testgen.common.models.scheduler" + + +@pytest.fixture +def mock_session(): + session = MagicMock() + with patch(f"{MODULE}.get_current_session", return_value=session): + yield session + + +# ─── upsert_for_retention ─────────────────────────────────────────── + + +def test_upsert_for_retention_inserts_when_missing(mock_session): + """No existing schedule for (project, JobKey.run_data_cleanup) → INSERT path: + creates a fresh JobSchedule and adds it to the session.""" + mock_session.scalars.return_value.first.return_value = None + + schedule = JobSchedule.upsert_for_retention( + project_code="proj", + retention_days=90, + cron_expr="0 1 * * *", + cron_tz="UTC", + ) + + mock_session.add.assert_called_once() + added = mock_session.add.call_args[0][0] + assert added is schedule + assert schedule.project_code == "proj" + assert schedule.key == JobKey.run_data_cleanup + assert schedule.kwargs == {"project_code": "proj", "retention_days": 90} + assert schedule.cron_expr == "0 1 * * *" + assert schedule.cron_tz == "UTC" + assert schedule.active is True + + +def test_upsert_for_retention_updates_when_present(mock_session): + """Existing schedule for the same (project, key) → UPDATE path: mutates in + place; does NOT add a new row (would otherwise violate the table's + UNIQUE constraint and duplicate schedules per project).""" + existing = JobSchedule( + project_code="proj", + key=JobKey.run_data_cleanup, + kwargs={"project_code": "proj", "retention_days": 180}, + cron_expr="0 1 * * *", + cron_tz="UTC", + active=False, + ) + mock_session.scalars.return_value.first.return_value = existing + + result = JobSchedule.upsert_for_retention( + project_code="proj", + retention_days=30, + cron_expr="0 2 * * *", + cron_tz="America/New_York", + ) + + mock_session.add.assert_not_called() + assert result is existing + assert existing.kwargs == {"project_code": "proj", "retention_days": 30} + assert existing.cron_expr == "0 2 * * *" + assert existing.cron_tz == "America/New_York" + # Re-activated even when the previous schedule had been deactivated + assert existing.active is True + + +def test_upsert_for_retention_reactivates_inactive_schedule(mock_session): + """A specific guard: if a project's retention schedule was disabled (active=False) + and the user re-enables retention, the upsert flips active back to True.""" + existing = JobSchedule( + project_code="proj", + key=JobKey.run_data_cleanup, + kwargs={}, + cron_expr="0 1 * * *", + cron_tz="UTC", + active=False, + ) + mock_session.scalars.return_value.first.return_value = existing + + JobSchedule.upsert_for_retention( + project_code="proj", + retention_days=180, + cron_expr=DEFAULT_DATA_CLEANUP_CRON, + cron_tz="UTC", + ) + + assert existing.active is True + + +def test_upsert_for_retention_does_not_commit(mock_session): + """Like other model methods: the helper participates in the caller's + transaction; it must not commit on its own. The save() path is owned by + the request scope (database_session or safe_rerun).""" + mock_session.scalars.return_value.first.return_value = None + + JobSchedule.upsert_for_retention( + project_code="proj", + retention_days=180, + cron_expr=DEFAULT_DATA_CLEANUP_CRON, + cron_tz="UTC", + ) + + mock_session.commit.assert_not_called() + + +# ─── delete_for_retention ─────────────────────────────────────────── + + +def test_delete_for_retention_executes_scoped_delete(mock_session): + """Issues a single DELETE filtered to (project_code, JobKey.run_data_cleanup). + Idempotent — safe to call when no schedule exists (mock_session.execute + is a no-op).""" + JobSchedule.delete_for_retention("proj") + + mock_session.execute.assert_called_once() + stmt = mock_session.execute.call_args.args[0] + rendered = str(stmt.compile(compile_kwargs={"literal_binds": True})) + assert "DELETE FROM job_schedules" in rendered + assert "proj" in rendered + assert JobKey.run_data_cleanup in rendered + + +def test_delete_for_retention_does_not_commit(mock_session): + JobSchedule.delete_for_retention("proj") + mock_session.commit.assert_not_called() diff --git a/tests/unit/scheduler/test_scheduler_cli.py b/tests/unit/scheduler/test_scheduler_cli.py index 8dfa045d..ed8f1556 100644 --- a/tests/unit/scheduler/test_scheduler_cli.py +++ b/tests/unit/scheduler/test_scheduler_cli.py @@ -7,6 +7,8 @@ import pytest +from testgen.commands.job_registry import JobConfig +from testgen.common.enums import JobKey, JobSource from testgen.common.models.job_execution import JobExecution from testgen.common.models.scheduler import JobSchedule from testgen.scheduler.base import DelayedPolicy @@ -49,7 +51,7 @@ def db_jobs(scheduler_instance): @pytest.fixture def job_data(): - with patch.dict("testgen.commands.job_registry.JOB_DISPATCH", {"test-job": Mock()}): + with patch.dict("testgen.commands.job_registry.JOB_DISPATCH", {"test-job": JobConfig(handler=Mock())}): yield { "cron_expr": "*/5 9-17 * * *", "cron_tz": "UTC", @@ -95,6 +97,32 @@ def test_job_start(scheduler_instance, cli_job): mock_session.commit.assert_called_once() +def test_job_start_tags_source_from_job_config(scheduler_instance, job_data): + """Scheduled executions inherit `JobConfig.scheduler_source` as their + `JobExecution.source`. Retention cleanup (registered with + `scheduler_source="system"`) gets `source="system"` so MCP / REST / + api/deps filters auto-hide it from user-facing surfaces.""" + from testgen.scheduler.cli_scheduler import CliJob + + system_job_data = {**job_data, "key": JobKey.run_data_cleanup} + system_cli_job = CliJob(**system_job_data, delayed_policy=DelayedPolicy.SKIP) + + mock_session = MagicMock() + mock_session.__enter__ = Mock(return_value=mock_session) + mock_session.__exit__ = Mock(return_value=False) + with ( + patch.dict( + "testgen.commands.job_registry.JOB_DISPATCH", + {JobKey.run_data_cleanup: JobConfig(handler=Mock(), scheduler_source=JobSource.system)}, + ), + patch("testgen.common.models.Session", return_value=mock_session), + ): + scheduler_instance.start_job(system_cli_job, datetime.now(UTC)) + + added = mock_session.add.call_args[0][0] + assert added.source == "system" + + @pytest.mark.parametrize("proc_exit_code", [0, 1]) def test_proc_wrapper_status(proc_exit_code, scheduler_instance): mock_session = MagicMock() diff --git a/tests/unit/scheduler/test_scheduler_poll.py b/tests/unit/scheduler/test_scheduler_poll.py index 6137c2f2..39082942 100644 --- a/tests/unit/scheduler/test_scheduler_poll.py +++ b/tests/unit/scheduler/test_scheduler_poll.py @@ -4,8 +4,8 @@ import pytest -from testgen.commands.job_registry import JOB_DISPATCH -from testgen.common.enums import JobStatus +from testgen.commands.job_registry import JOB_DISPATCH, JobConfig +from testgen.common.enums import JobKey, JobStatus from testgen.common.models.job_execution import JobExecution from testgen.scheduler.cli_scheduler import CliScheduler @@ -36,7 +36,7 @@ def scheduler_instance(): def job_exec(): return JobExecution( id=uuid4(), - job_key="run-tests", + job_key=JobKey.run_tests, kwargs={"test_suite_id": "suite-123"}, source="scheduler", status="claimed", @@ -56,7 +56,7 @@ def test_dispatch_spawns_process(scheduler_instance, job_exec, mock_session): proc_mock = MagicMock() with ( - patch.dict(JOB_DISPATCH, {"run-tests": Mock()}, clear=False), + patch.dict(JOB_DISPATCH, {JobKey.run_tests: JobConfig(handler=Mock())}, clear=False), patch(f"{SCHEDULER_MODULE}.subprocess.Popen", return_value=proc_mock) as popen_mock, patch(f"{SCHEDULER_MODULE}.threading.Thread") as thread_mock, ): @@ -218,7 +218,7 @@ def test_poll_loop_routes_cancel_requested(scheduler_instance, mock_session): """Cancel_requested rows are routed to _handle_cancellation, not _dispatch.""" cancel_job = JobExecution( id=uuid4(), - job_key="run-tests", + job_key=JobKey.run_tests, kwargs={}, source="ui", status=JobStatus.CANCEL_REQUESTED, @@ -254,7 +254,7 @@ def test_start_job_submits_execution(scheduler_instance, mock_session): cron_expr="*/5 * * * *", cron_tz="UTC", delayed_policy=DelayedPolicy.SKIP, - key="run-profile", + key=JobKey.run_profile, kwargs={"table_group_id": "tg-123"}, job_schedule_id=schedule_id, ) @@ -263,7 +263,7 @@ def test_start_job_submits_execution(scheduler_instance, mock_session): mock_session.add.assert_called_once() added = mock_session.add.call_args[0][0] - assert added.job_key == "run-profile" + assert added.job_key == JobKey.run_profile assert added.kwargs == {"table_group_id": "tg-123"} assert added.source == "scheduler" assert added.job_schedule_id == schedule_id diff --git a/tests/unit/ui/test_project_settings.py b/tests/unit/ui/test_project_settings.py index 08ab933b..d5acb6b2 100644 --- a/tests/unit/ui/test_project_settings.py +++ b/tests/unit/ui/test_project_settings.py @@ -2,6 +2,7 @@ import pytest +from testgen.common.enums import JobKey from testgen.ui.views.project_settings import ProjectSettingsPage pytestmark = pytest.mark.unit @@ -19,11 +20,13 @@ def mock_session(): yield session -def _make_page(use_dq_score_weights=True): +def _make_page(use_dq_score_weights=True, data_retention_enabled=False, data_retention_days=None): page = ProjectSettingsPage.__new__(ProjectSettingsPage) page.project = MagicMock() page.project.use_dq_score_weights = use_dq_score_weights page.project.project_name = "My Project" + page.project.data_retention_enabled = data_retention_enabled + page.project.data_retention_days = data_retention_days return page @@ -34,7 +37,7 @@ def test_update_project_submits_recalculate_job_when_weights_toggled_on(mock_ses page.update_project("proj", {"name": "My Project", "use_dq_score_weights": True}) mock_je.submit.assert_called_once_with( - job_key="recalculate-project-scores", + job_key=JobKey.recalculate_project_scores, kwargs={"project_code": "proj"}, source="ui", project_code="proj", @@ -48,7 +51,7 @@ def test_update_project_submits_recalculate_job_when_weights_toggled_off(mock_se page.update_project("proj", {"name": "My Project", "use_dq_score_weights": False}) mock_je.submit.assert_called_once_with( - job_key="recalculate-project-scores", + job_key=JobKey.recalculate_project_scores, kwargs={"project_code": "proj"}, source="ui", project_code="proj", @@ -86,3 +89,80 @@ def test_update_project_raises_on_duplicate_name(mock_session): ): mock_select.return_value = [MagicMock(project_name="Other Project")] page.update_project("proj", {"name": "Other Project", "use_dq_score_weights": True}) + + +# ─── Data retention ────────────────────────────────────────────────── + + +def test_update_project_upserts_schedule_when_retention_enabled(mock_session): + page = _make_page(data_retention_enabled=False) + payload = { + "name": "My Project", + "use_dq_score_weights": True, + "data_retention_enabled": True, + "data_retention_days": 90, + "retention_cron_expr": "0 2 * * *", + "retention_cron_tz": "America/New_York", + } + + with ( + patch(f"{MODULE}.JobExecution"), + patch(f"{MODULE}.JobSchedule") as mock_schedule, + ): + page.update_project("proj", payload) + + mock_schedule.upsert_for_retention.assert_called_once_with( + project_code="proj", + retention_days=90, + cron_expr="0 2 * * *", + cron_tz="America/New_York", + ) + mock_schedule.delete_for_retention.assert_not_called() + assert page.project.data_retention_enabled is True + assert page.project.data_retention_days == 90 + + +def test_update_project_deletes_schedule_when_retention_disabled(mock_session): + """No-op cleanup contract: disabling retention removes the schedule so the + cleanup job never fires for this project.""" + page = _make_page(data_retention_enabled=True, data_retention_days=180) + payload = { + "name": "My Project", + "use_dq_score_weights": True, + "data_retention_enabled": False, + } + + with ( + patch(f"{MODULE}.JobExecution"), + patch(f"{MODULE}.JobSchedule") as mock_schedule, + ): + page.update_project("proj", payload) + + mock_schedule.delete_for_retention.assert_called_once_with("proj") + mock_schedule.upsert_for_retention.assert_not_called() + assert page.project.data_retention_enabled is False + # When disabled the days column is nulled out (matches the migration's nullable column). + assert page.project.data_retention_days is None + + +def test_update_project_uses_default_days_when_missing(mock_session): + """Enabling retention without an explicit days value falls back to the page's + DEFAULT_RETENTION_DAYS constant (180) so the schedule is still well-formed.""" + page = _make_page(data_retention_enabled=False) + payload = { + "name": "My Project", + "use_dq_score_weights": True, + "data_retention_enabled": True, + # data_retention_days omitted + "retention_cron_expr": "0 1 * * *", + "retention_cron_tz": "UTC", + } + + with ( + patch(f"{MODULE}.JobExecution"), + patch(f"{MODULE}.JobSchedule") as mock_schedule, + ): + page.update_project("proj", payload) + + kwargs = mock_schedule.upsert_for_retention.call_args.kwargs + assert kwargs["retention_days"] == 180 From b95ccea9c4a4fdd7505ad223302c9d4b53918d62 Mon Sep 17 00:00:00 2001 From: Ricardo Boni Date: Thu, 28 May 2026 13:04:15 -0400 Subject: [PATCH 41/58] feat(mcp): test definition note CRUD tools (TG-1086) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add create_test_note, update_test_note, and delete_test_note. All gated on `edit` permission; update and delete additionally require the caller to be the note's author. Extends `add_note` to return the persisted instance so the new note ID and timestamp can be surfaced without a second roundtrip. Adds `Test note ID` column to `list_test_notes` output to keep the producer→consumer chain workable. Co-Authored-By: Claude Opus 4.7 (1M context) --- testgen/common/models/test_definition.py | 15 +- testgen/mcp/server.py | 6 + testgen/mcp/tools/common.py | 25 ++- testgen/mcp/tools/test_definitions.py | 103 +++++++++- tests/unit/mcp/test_tools_common.py | 34 ++++ tests/unit/mcp/test_tools_test_definitions.py | 190 +++++++++++++++++- 6 files changed, 363 insertions(+), 10 deletions(-) diff --git a/testgen/common/models/test_definition.py b/testgen/common/models/test_definition.py index 82b7b53c..e0a71458 100644 --- a/testgen/common/models/test_definition.py +++ b/testgen/common/models/test_definition.py @@ -1,6 +1,6 @@ from collections.abc import Iterable from dataclasses import dataclass -from datetime import datetime +from datetime import UTC, datetime from enum import StrEnum from itertools import zip_longest from typing import ClassVar, Literal @@ -673,11 +673,18 @@ class TestDefinitionNote(Base): updated_at: datetime = Column(postgresql.TIMESTAMP) @classmethod - def add_note(cls, test_definition_id: str | UUID, detail: str, username: str) -> None: + def add_note(cls, test_definition_id: str | UUID, detail: str, username: str) -> "TestDefinitionNote": + """Insert a note and return the persisted instance with ``id`` and ``created_at`` populated.""" db_session = get_current_session() - db_session.execute( - insert(cls).values(test_definition_id=test_definition_id, detail=detail, created_by=username) + note = cls( + test_definition_id=test_definition_id, + detail=detail, + created_by=username, + created_at=datetime.now(UTC).replace(tzinfo=None), ) + db_session.add(note) + db_session.flush() + return note @classmethod def update_note(cls, note_id: str | UUID, detail: str) -> None: diff --git a/testgen/mcp/server.py b/testgen/mcp/server.py index 901c665b..a29e84b6 100644 --- a/testgen/mcp/server.py +++ b/testgen/mcp/server.py @@ -194,11 +194,14 @@ def build_mcp_server( from testgen.mcp.tools.test_definitions import ( bulk_update_tests, create_test, + create_test_note, + delete_test_note, get_test, list_test_notes, list_test_types, list_tests, update_test, + update_test_note, validate_custom_test, ) from testgen.mcp.tools.test_results import ( @@ -276,6 +279,9 @@ def safe_prompt(fn): safe_tool(update_test) safe_tool(validate_custom_test) safe_tool(bulk_update_tests) + safe_tool(create_test_note) + safe_tool(update_test_note) + safe_tool(delete_test_note) safe_tool(list_hygiene_issues) safe_tool(get_hygiene_issue) safe_tool(search_hygiene_issues) diff --git a/testgen/mcp/tools/common.py b/testgen/mcp/tools/common.py index 1da05a6b..ea472419 100644 --- a/testgen/mcp/tools/common.py +++ b/testgen/mcp/tools/common.py @@ -19,7 +19,7 @@ from testgen.common.models.scheduler import SCHEDULABLE_JOB_KEYS, JobSchedule from testgen.common.models.scores import ScoreCategory, ScoreDefinition from testgen.common.models.table_group import TableGroup -from testgen.common.models.test_definition import TestDefinition, TestType +from testgen.common.models.test_definition import TestDefinition, TestDefinitionNote, TestType from testgen.common.models.test_result import TestResultStatus from testgen.common.models.test_suite import TestSuite from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError @@ -574,6 +574,29 @@ def resolve_test_definition(test_definition_id: str) -> TestDefinition: return td +def resolve_test_note(test_note_id: str) -> TestDefinitionNote: + """Resolve a test note ID to the live ORM model, collapsing missing-or-inaccessible. + + Filters monitor suites and project access via the note's parent test definition. + """ + note_uuid = parse_uuid(test_note_id, "test_note_id") + perms = get_project_permissions() + query = ( + select(TestDefinitionNote) + .join(TestDefinition, TestDefinitionNote.test_definition_id == TestDefinition.id) + .join(TestSuite, TestDefinition.test_suite_id == TestSuite.id) + .where( + TestDefinitionNote.id == note_uuid, + TestSuite.is_monitor.isnot(True), + TestSuite.project_code.in_(perms.allowed_codes), + ) + ) + note = get_current_session().scalars(query).first() + if note is None: + raise MCPResourceNotAccessible("Test note", test_note_id) + return note + + def resolve_schedule(schedule_id: str) -> JobSchedule: """Resolve a user-managed schedule ID, collapsing missing-or-inaccessible into one error path.""" sched_uuid = parse_uuid(schedule_id, "schedule_id") diff --git a/testgen/mcp/tools/test_definitions.py b/testgen/mcp/tools/test_definitions.py index 8a87e677..9d05f56a 100644 --- a/testgen/mcp/tools/test_definitions.py +++ b/testgen/mcp/tools/test_definitions.py @@ -18,7 +18,7 @@ ) from testgen.common.models.test_result import TestResult from testgen.mcp.exceptions import MCPUserError -from testgen.mcp.permissions import get_project_permissions, mcp_permission +from testgen.mcp.permissions import get_authorized_mcp_user, get_project_permissions, mcp_permission from testgen.mcp.tools.common import ( DocGroup, format_page_footer, @@ -27,6 +27,7 @@ parse_quality_dimension, parse_uuid, resolve_test_definition, + resolve_test_note, resolve_test_suite, resolve_test_type, validate_limit, @@ -267,15 +268,111 @@ def list_test_notes(test_definition_id: str) -> str: doc.text(f"{len(notes)} note(s).") doc.table( - headers=["Date", "Author", "Note", "Updated"], + headers=["Test note ID", "Date", "Author", "Note", "Updated"], rows=[ - [n["created_at"], n["created_by"], n["detail"], n["updated_at"]] + [n["id"], n["created_at"], n["created_by"], n["detail"], n["updated_at"]] for n in notes ], + code=[0], ) return doc.render() +def _validate_note_body(body: str) -> None: + if not isinstance(body, str) or not body.strip(): + raise MCPUserError("`body` cannot be empty or whitespace-only.") + + +def _note_parent_label(summary: TestDefinitionSummary) -> str: + where = f"`{summary.column_name}` in `{summary.table_name}`" if summary.column_name else f"`{summary.table_name}`" + return f"{summary.display_name} on {where}" + + +@with_database_session +@mcp_permission("edit") +def create_test_note(test_definition_id: str, body: str) -> str: + """Attach a note to a test definition. + + Args: + test_definition_id: UUID of the test definition, e.g. from ``list_tests``. + body: Note body (free-text). Empty or whitespace-only is rejected. + """ + _validate_note_body(body) + td = resolve_test_definition(test_definition_id) + username = get_authorized_mcp_user().username + + note = TestDefinitionNote.add_note(td.id, body, username) + + perms = get_project_permissions() + summary = TestDefinition.get_for_project(td.id, perms.allowed_codes) + + doc = MdDoc() + doc.text(f"**Note added** to {_note_parent_label(summary)}.") + doc.field("Test note ID", note.id, code=True) + doc.field("Author", username) + doc.field("Date", note.created_at) + doc.field("Note", note.detail) + return doc.render() + + +@with_database_session +@mcp_permission("edit") +def update_test_note(test_note_id: str, body: str) -> str: + """Replace the body of a test note. Only the note's author can update it. + + Args: + test_note_id: UUID of the test note, e.g. from ``list_test_notes`` or ``create_test_note``. + body: New note body (free-text). Empty or whitespace-only is rejected. + """ + _validate_note_body(body) + note = resolve_test_note(test_note_id) + username = get_authorized_mcp_user().username + if note.created_by != username: + raise MCPUserError("You can only edit notes you authored.") + + before_body = note.detail + TestDefinitionNote.update_note(note.id, body) + + perms = get_project_permissions() + summary = TestDefinition.get_for_project(note.test_definition_id, perms.allowed_codes) + + doc = MdDoc() + doc.text(f"**Note updated** on {_note_parent_label(summary)}.") + doc.table( + headers=["Field", "Before", "After"], + rows=[["Note", before_body, body]], + ) + return doc.render() + + +@with_database_session +@mcp_permission("edit") +def delete_test_note(test_note_id: str) -> str: + """Delete a test note. Only the note's author can delete it. + + Args: + test_note_id: UUID of the test note, e.g. from ``list_test_notes``. + """ + note = resolve_test_note(test_note_id) + username = get_authorized_mcp_user().username + if note.created_by != username: + raise MCPUserError("You can only delete notes you authored.") + + author = note.created_by + created_at = note.created_at + td_id = note.test_definition_id + TestDefinitionNote.delete_note(note.id) + + perms = get_project_permissions() + summary = TestDefinition.get_for_project(td_id, perms.allowed_codes) + + doc = MdDoc() + doc.text(f"**Note deleted** from {_note_parent_label(summary)}.") + doc.field("Author", author) + doc.field("Date", created_at) + return doc.render() + + def _append_parameters_section(doc: MdDoc, td: TestDefinitionSummary) -> None: """Build the editable parameters table from test type metadata. diff --git a/tests/unit/mcp/test_tools_common.py b/tests/unit/mcp/test_tools_common.py index aeaffca9..3b5e7f32 100644 --- a/tests/unit/mcp/test_tools_common.py +++ b/tests/unit/mcp/test_tools_common.py @@ -31,6 +31,7 @@ parse_uuid, resolve_issue_type, resolve_profiling_run, + resolve_test_note, validate_limit, validate_page, ) @@ -345,6 +346,39 @@ def test_resolve_profiling_run_invalid_uuid(): resolve_profiling_run("not-a-uuid") +# --- resolve_test_note --- + + +@patch("testgen.mcp.tools.common.get_project_permissions") +@patch("testgen.mcp.tools.common.get_current_session") +def test_resolve_test_note_happy_path(mock_get_session, mock_get_perms): + note = MagicMock() + session = MagicMock() + session.scalars.return_value.first.return_value = note + mock_get_session.return_value = session + mock_get_perms.return_value = _mock_perms() + + assert resolve_test_note(str(uuid4())) is note + + +@patch("testgen.mcp.tools.common.get_project_permissions") +@patch("testgen.mcp.tools.common.get_current_session") +def test_resolve_test_note_missing_or_inaccessible(mock_get_session, mock_get_perms): + """Missing note, monitor-suite parent, and forbidden project all collapse to one error.""" + session = MagicMock() + session.scalars.return_value.first.return_value = None + mock_get_session.return_value = session + mock_get_perms.return_value = _mock_perms() + + with pytest.raises(MCPResourceNotAccessible, match=r"Test note .* not found or not accessible"): + resolve_test_note(str(uuid4())) + + +def test_resolve_test_note_invalid_uuid(): + with pytest.raises(MCPUserError, match="Invalid test_note_id"): + resolve_test_note("not-a-uuid") + + # --- parse_pii_category --- diff --git a/tests/unit/mcp/test_tools_test_definitions.py b/tests/unit/mcp/test_tools_test_definitions.py index 90213043..46488143 100644 --- a/tests/unit/mcp/test_tools_test_definitions.py +++ b/tests/unit/mcp/test_tools_test_definitions.py @@ -415,9 +415,11 @@ def test_list_test_notes_basic(mock_td, mock_notes, db_session_mock): td.column_name = "name" mock_td.get_for_project.return_value = td + note_id_1 = str(uuid4()) + note_id_2 = str(uuid4()) mock_notes.get_notes.return_value = [ - {"detail": "Threshold looks wrong", "created_by": "alice", "created_at": "2026-04-01T10:00:00", "updated_at": None}, - {"detail": "Confirmed with team", "created_by": "bob", "created_at": "2026-04-02T14:30:00", "updated_at": "2026-04-03T09:00:00"}, + {"id": note_id_1, "detail": "Threshold looks wrong", "created_by": "alice", "created_at": "2026-04-01T10:00:00", "updated_at": None}, + {"id": note_id_2, "detail": "Confirmed with team", "created_by": "bob", "created_at": "2026-04-02T14:30:00", "updated_at": "2026-04-03T09:00:00"}, ] from testgen.mcp.tools.test_definitions import list_test_notes @@ -432,6 +434,9 @@ def test_list_test_notes_basic(mock_td, mock_notes, db_session_mock): assert "alice" in result assert "2026-04-01 10:00" in result assert "2026-04-03 09:00" in result + assert "Test note ID" in result + assert note_id_1 in result + assert note_id_2 in result @patch("testgen.mcp.tools.test_definitions.TestDefinitionNote") @@ -468,6 +473,187 @@ def test_list_test_notes_invalid_uuid(db_session_mock): list_test_notes("garbage") +# -- create_test_note --------------------------------------------------------- + + +def _make_note_summary(): + """Minimal TestDefinitionSummary mock for note-tool rendering.""" + summary = MagicMock() + summary.display_name = "Alpha Truncation" + summary.table_name = "orders" + summary.column_name = "email" + return summary + + +@patch("testgen.mcp.tools.test_definitions.TestDefinitionNote") +@patch("testgen.mcp.tools.test_definitions.TestDefinition") +@patch("testgen.mcp.tools.test_definitions.resolve_test_definition") +def test_create_test_note_happy_path( + mock_resolve_td, mock_td, mock_note_model, mcp_user, db_session_mock, +): + mcp_user.username = "test_user" + td = MagicMock(id=uuid4()) + mock_resolve_td.return_value = td + + note_instance = MagicMock( + id=uuid4(), + detail="Threshold widened — confirmed with team", + created_at="2026-05-27T10:00:00", + ) + mock_note_model.add_note.return_value = note_instance + mock_td.get_for_project.return_value = _make_note_summary() + + from testgen.mcp.tools.test_definitions import create_test_note + + result = create_test_note(str(td.id), "Threshold widened — confirmed with team") + + assert "Note added" in result + assert "Alpha Truncation" in result + assert "`email`" in result + assert "`orders`" in result + assert "test_user" in result + assert str(note_instance.id) in result + mock_note_model.add_note.assert_called_once_with(td.id, "Threshold widened — confirmed with team", "test_user") + + +@patch("testgen.mcp.tools.test_definitions.resolve_test_definition") +def test_create_test_note_rejects_empty_body(mock_resolve_td, db_session_mock): + from testgen.mcp.tools.test_definitions import create_test_note + + with pytest.raises(MCPUserError, match="cannot be empty"): + create_test_note(str(uuid4()), "") + with pytest.raises(MCPUserError, match="cannot be empty"): + create_test_note(str(uuid4()), " \n\t ") + + mock_resolve_td.assert_not_called() + + +def test_create_test_note_invalid_uuid(db_session_mock): + from testgen.mcp.tools.test_definitions import create_test_note + + with pytest.raises(MCPUserError, match="not a valid UUID"): + create_test_note("garbage", "valid detail") + + +# -- update_test_note --------------------------------------------------------- + + +@patch("testgen.mcp.tools.test_definitions.TestDefinitionNote") +@patch("testgen.mcp.tools.test_definitions.TestDefinition") +@patch("testgen.mcp.tools.test_definitions.resolve_test_note") +def test_update_test_note_happy_path( + mock_resolve_note, mock_td, mock_note_model, mcp_user, db_session_mock, +): + mcp_user.username = "test_user" + note = MagicMock( + id=uuid4(), + test_definition_id=uuid4(), + created_by="test_user", + detail="original body", + ) + mock_resolve_note.return_value = note + mock_td.get_for_project.return_value = _make_note_summary() + + from testgen.mcp.tools.test_definitions import update_test_note + + result = update_test_note(str(note.id), "rewritten body") + + assert "Note updated" in result + assert "Alpha Truncation" in result + assert "original body" in result + assert "rewritten body" in result + mock_note_model.update_note.assert_called_once_with(note.id, "rewritten body") + + +@patch("testgen.mcp.tools.test_definitions.TestDefinitionNote") +@patch("testgen.mcp.tools.test_definitions.resolve_test_note") +def test_update_test_note_non_author_rejected( + mock_resolve_note, mock_note_model, mcp_user, db_session_mock, +): + mcp_user.username = "test_user" + note = MagicMock(created_by="someone_else") + mock_resolve_note.return_value = note + + from testgen.mcp.tools.test_definitions import update_test_note + + with pytest.raises(MCPUserError, match="You can only edit notes you authored"): + update_test_note(str(uuid4()), "new body") + + mock_note_model.update_note.assert_not_called() + + +@patch("testgen.mcp.tools.test_definitions.resolve_test_note") +def test_update_test_note_rejects_empty_body(mock_resolve_note, db_session_mock): + from testgen.mcp.tools.test_definitions import update_test_note + + with pytest.raises(MCPUserError, match="cannot be empty"): + update_test_note(str(uuid4()), "") + with pytest.raises(MCPUserError, match="cannot be empty"): + update_test_note(str(uuid4()), " ") + + mock_resolve_note.assert_not_called() + + +def test_update_test_note_invalid_uuid(db_session_mock): + from testgen.mcp.tools.test_definitions import update_test_note + + with pytest.raises(MCPUserError, match="not a valid UUID"): + update_test_note("garbage", "valid detail") + + +# -- delete_test_note --------------------------------------------------------- + + +@patch("testgen.mcp.tools.test_definitions.TestDefinitionNote") +@patch("testgen.mcp.tools.test_definitions.TestDefinition") +@patch("testgen.mcp.tools.test_definitions.resolve_test_note") +def test_delete_test_note_happy_path( + mock_resolve_note, mock_td, mock_note_model, mcp_user, db_session_mock, +): + mcp_user.username = "test_user" + note = MagicMock( + id=uuid4(), + test_definition_id=uuid4(), + created_by="test_user", + created_at="2026-05-27T10:00:00", + ) + mock_resolve_note.return_value = note + mock_td.get_for_project.return_value = _make_note_summary() + + from testgen.mcp.tools.test_definitions import delete_test_note + + result = delete_test_note(str(note.id)) + + assert "Note deleted" in result + assert "Alpha Truncation" in result + assert "test_user" in result + mock_note_model.delete_note.assert_called_once_with(note.id) + + +@patch("testgen.mcp.tools.test_definitions.TestDefinitionNote") +@patch("testgen.mcp.tools.test_definitions.resolve_test_note") +def test_delete_test_note_non_author_rejected( + mock_resolve_note, mock_note_model, mcp_user, db_session_mock, +): + mcp_user.username = "test_user" + note = MagicMock(created_by="someone_else") + mock_resolve_note.return_value = note + + from testgen.mcp.tools.test_definitions import delete_test_note + + with pytest.raises(MCPUserError, match="You can only delete notes you authored"): + delete_test_note(str(uuid4())) + + mock_note_model.delete_note.assert_not_called() + + +def test_delete_test_note_invalid_uuid(db_session_mock): + from testgen.mcp.tools.test_definitions import delete_test_note + + with pytest.raises(MCPUserError, match="not a valid UUID"): + delete_test_note("garbage") + + # -- list_test_types ---------------------------------------------------------- From 0b651a2d13646d7edb77c4f5be5638f6ee6f42b3 Mon Sep 17 00:00:00 2001 From: Aarthy Adityan Date: Thu, 28 May 2026 15:25:41 -0400 Subject: [PATCH 42/58] misc: remove noisy streamlit logs in debug mode --- testgen/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/testgen/__main__.py b/testgen/__main__.py index a1ea67a1..ceb396c5 100644 --- a/testgen/__main__.py +++ b/testgen/__main__.py @@ -982,7 +982,7 @@ def init_ui(): "run", app_file, "--browser.gatherUsageStats=false", - f"--logger.level={'debug' if settings.IS_DEBUG else 'error'}", + "--logger.level=error", "--client.showErrorDetails=none", "--client.toolbarMode=minimal", "--server.enableStaticServing=true", From a36ee7d0a7f29f1e9a5f38cd79cd78b0fa298ccc Mon Sep 17 00:00:00 2001 From: Ricardo Boni Date: Fri, 29 May 2026 12:27:40 -0400 Subject: [PATCH 43/58] refactor(mcp): apply TG-1086 review feedback - create_test_note: escape note body before rendering via doc.field (MdDoc.escape) - TestDefinitionNote.update_note: use datetime.now(UTC) to match add_note and the codebase clock convention Co-Authored-By: Claude Opus 4.8 (1M context) --- testgen/common/models/test_definition.py | 4 +++- testgen/mcp/tools/test_definitions.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/testgen/common/models/test_definition.py b/testgen/common/models/test_definition.py index e0a71458..ec9ddde0 100644 --- a/testgen/common/models/test_definition.py +++ b/testgen/common/models/test_definition.py @@ -689,7 +689,9 @@ def add_note(cls, test_definition_id: str | UUID, detail: str, username: str) -> @classmethod def update_note(cls, note_id: str | UUID, detail: str) -> None: db_session = get_current_session() - db_session.execute(update(cls).where(cls.id == note_id).values(detail=detail, updated_at=func.now())) + db_session.execute( + update(cls).where(cls.id == note_id).values(detail=detail, updated_at=datetime.now(UTC).replace(tzinfo=None)) + ) @classmethod def delete_note(cls, note_id: str | UUID) -> None: diff --git a/testgen/mcp/tools/test_definitions.py b/testgen/mcp/tools/test_definitions.py index 9d05f56a..856e2e54 100644 --- a/testgen/mcp/tools/test_definitions.py +++ b/testgen/mcp/tools/test_definitions.py @@ -311,7 +311,7 @@ def create_test_note(test_definition_id: str, body: str) -> str: doc.field("Test note ID", note.id, code=True) doc.field("Author", username) doc.field("Date", note.created_at) - doc.field("Note", note.detail) + doc.field("Note", MdDoc.escape(note.detail)) return doc.render() From e9b3c0e357c62054c38675bfde300cf8d92fa349 Mon Sep 17 00:00:00 2001 From: Astor Date: Fri, 29 May 2026 01:37:30 -0300 Subject: [PATCH 44/58] feat: add feedback popup and help item --- pyproject.toml | 4 + testgen/commands/run_launch_db_config.py | 2 + testgen/common/mixpanel_service.py | 11 + testgen/common/models/user.py | 42 +++- testgen/settings.py | 8 + .../030_initialize_new_schema_structure.sql | 3 +- .../040_populate_new_schema_project.sql | 5 +- .../dbupgrade/0192_incremental_upgrade.sql | 4 + .../frontend/js/pages/feedback_widget.js | 225 ++++++++++++++++++ testgen/ui/components/widgets/__init__.py | 7 + testgen/ui/components/widgets/page.py | 43 ++++ testgen/ui/navigation/router.py | 25 ++ testgen/ui/session.py | 2 + testgen/ui/static/css/style.css | 4 + testgen/ui/static/js/components/help_menu.js | 8 + 15 files changed, 387 insertions(+), 6 deletions(-) create mode 100644 testgen/template/dbupgrade/0192_incremental_upgrade.sql create mode 100644 testgen/ui/components/frontend/js/pages/feedback_widget.js diff --git a/pyproject.toml b/pyproject.toml index 49262388..29be615c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -401,3 +401,7 @@ asset_dir = "ui/components/frontend/js" [[tool.streamlit.component.components]] name = "sidebar" asset_dir = "ui/components/frontend/js" + +[[tool.streamlit.component.components]] +name = "feedback_widget" +asset_dir = "ui/components/frontend/js" diff --git a/testgen/commands/run_launch_db_config.py b/testgen/commands/run_launch_db_config.py index f7b69bcb..824d200d 100644 --- a/testgen/commands/run_launch_db_config.py +++ b/testgen/commands/run_launch_db_config.py @@ -7,6 +7,7 @@ from testgen.common.database.database_service import get_queries_for_command from testgen.common.encrypt import EncryptText, encrypt_ui_password from testgen.common.models import with_database_session +from testgen.common.models.user import initial_feedback_popup_seed from testgen.common.read_file import get_template_files from testgen.common.read_yaml_metadata_records import import_metadata_records_from_yaml from testgen.common.standalone_postgres import EMBEDDED_HOST_SENTINEL, is_standalone_mode @@ -40,6 +41,7 @@ def _get_params_mapping() -> dict: "UI_USER_USERNAME": settings.USERNAME, "UI_USER_EMAIL": "", "UI_USER_ENCRYPTED_PASSWORD": ui_user_encrypted_password, + "LAST_FEEDBACK_POPUP_SEED": initial_feedback_popup_seed(), "SCHEMA_NAME": get_tg_schema(), "PROJECT_CODE": settings.PROJECT_KEY, "CONNECTION_ID": 1, diff --git a/testgen/common/mixpanel_service.py b/testgen/common/mixpanel_service.py index dba5c74f..1af2ae4c 100644 --- a/testgen/common/mixpanel_service.py +++ b/testgen/common/mixpanel_service.py @@ -54,6 +54,17 @@ def _hash_value(self, value: bytes | str, digest_size: int = 8) -> str: @safe_method def send_event(self, event_name, include_usage=False, **properties): + self._track(event_name, include_usage=include_usage, **properties) + + def send_feedback(self, **properties): + # User-submitted feedback is content the user explicitly chose to share + # so it is not gated by the TG_ANALYTICS opt-out. + try: + self._track("feedback", **properties) + except Exception: + LOG.exception("Error sending feedback") + + def _track(self, event_name, include_usage=False, **properties): properties.setdefault("instance_id", self.instance_id) properties.setdefault("edition", settings.DOCKER_HUB_REPOSITORY) properties.setdefault("version", settings.VERSION) diff --git a/testgen/common/models/user.py b/testgen/common/models/user.py index c96f271f..09022ff1 100644 --- a/testgen/common/models/user.py +++ b/testgen/common/models/user.py @@ -1,8 +1,9 @@ -from datetime import UTC, datetime -from typing import Self +from datetime import UTC, datetime, timedelta +from enum import StrEnum +from typing import Any, Self from uuid import UUID, uuid4 -from sqlalchemy import Boolean, Column, String, asc, func, select, update +from sqlalchemy import Boolean, Column, String, asc, func, select, text, update from sqlalchemy.dialects import postgresql from testgen.common.models import get_current_session @@ -11,6 +12,27 @@ from testgen.common.models.project_membership import RoleType +class PreferenceKey(StrEnum): + """Keys allowed in the User.preferences JSONB column.""" + + LAST_FEEDBACK_POPUP = "last_feedback_popup" + + +# Feedback popup cadence. The popup recurs every FEEDBACK_POPUP_INTERVAL. A new user's first +# popup is delayed by FEEDBACK_POPUP_INITIAL_DELAY (rather than showing on first login, when they +# have nothing to give feedback on yet) by seeding last_feedback_popup in the past at creation. +FEEDBACK_POPUP_INTERVAL = timedelta(days=30) +FEEDBACK_POPUP_INITIAL_DELAY = timedelta(days=1) + + +def initial_feedback_popup_seed() -> str: + return (datetime.now(UTC) - (FEEDBACK_POPUP_INTERVAL - FEEDBACK_POPUP_INITIAL_DELAY)).isoformat() + + +def default_user_preferences() -> dict: + return {PreferenceKey.LAST_FEEDBACK_POPUP: initial_feedback_popup_seed()} + + class User(Entity): __tablename__ = "auth_users" @@ -21,6 +43,9 @@ class User(Entity): password: str = Column(String) is_global_admin: bool = Column(Boolean, nullable=False, default=False) latest_login: datetime = Column(postgresql.TIMESTAMP) + preferences: dict = Column( + postgresql.JSONB, nullable=False, default=default_user_preferences, server_default=text("'{}'") + ) _get_by = "username" _default_order_by = (asc(func.lower(username)),) @@ -40,6 +65,17 @@ def save(self, update_latest_login: bool = False) -> None: self.latest_login = datetime.now(UTC) super().save() + def get_preference(self, key: PreferenceKey, default: Any = None) -> Any: + return self.preferences.get(key, default) + + def set_preference(self, key: PreferenceKey, value: Any) -> None: + self.preferences[key] = value + self.update_preferences() + + def update_preferences(self) -> None: + query = update(User).where(User.id == self.id).values(preferences=self.preferences) + get_current_session().execute(query) + @classmethod def get(cls, identifier: str) -> Self | None: query = select(cls).where(func.lower(User.username) == func.lower(identifier)) diff --git a/testgen/settings.py b/testgen/settings.py index 339592cf..69e922dc 100644 --- a/testgen/settings.py +++ b/testgen/settings.py @@ -512,6 +512,14 @@ def _ssl_files_present() -> bool: Disables sending usage data when set to any value except "true" and "yes". Defaults to "yes" """ +DISABLE_FEEDBACK_POPUP: bool = getenv("TG_DISABLE_FEEDBACK_POPUP", "no").lower() in ("yes", "true") +""" +When set to "yes" or "true", suppresses the periodic feedback popup entirely. + +from env variable: `TG_DISABLE_FEEDBACK_POPUP` +defaults to: `no` +""" + JOB_POLL_INTERVAL: int = int(getenv("TG_JOB_POLL_INTERVAL", "5")) """ Seconds between polls for pending job executions. diff --git a/testgen/template/dbsetup/030_initialize_new_schema_structure.sql b/testgen/template/dbsetup/030_initialize_new_schema_structure.sql index 75dd441f..66e4db8b 100644 --- a/testgen/template/dbsetup/030_initialize_new_schema_structure.sql +++ b/testgen/template/dbsetup/030_initialize_new_schema_structure.sql @@ -713,7 +713,8 @@ CREATE TABLE auth_users ( name VARCHAR(256), password VARCHAR(120), is_global_admin BOOLEAN NOT NULL DEFAULT FALSE, - latest_login TIMESTAMP + latest_login TIMESTAMP, + preferences JSONB NOT NULL DEFAULT '{}' ); ALTER TABLE auth_users diff --git a/testgen/template/dbsetup/040_populate_new_schema_project.sql b/testgen/template/dbsetup/040_populate_new_schema_project.sql index 0e3e5b7c..204b84c2 100644 --- a/testgen/template/dbsetup/040_populate_new_schema_project.sql +++ b/testgen/template/dbsetup/040_populate_new_schema_project.sql @@ -22,13 +22,14 @@ SELECT gen_random_uuid(), WITH inserted_user AS ( INSERT INTO auth_users - (username, email, name, password, is_global_admin) + (username, email, name, password, is_global_admin, preferences) SELECT '{UI_USER_USERNAME}' as username, '{UI_USER_EMAIL}' as email, '{UI_USER_NAME}' as name, '{UI_USER_ENCRYPTED_PASSWORD}' as password, - true as is_global_admin + true as is_global_admin, + jsonb_build_object('last_feedback_popup', '{LAST_FEEDBACK_POPUP_SEED}') as preferences RETURNING id ) INSERT INTO project_memberships diff --git a/testgen/template/dbupgrade/0192_incremental_upgrade.sql b/testgen/template/dbupgrade/0192_incremental_upgrade.sql new file mode 100644 index 00000000..ce8c1158 --- /dev/null +++ b/testgen/template/dbupgrade/0192_incremental_upgrade.sql @@ -0,0 +1,4 @@ +SET SEARCH_PATH TO {SCHEMA_NAME}; + +ALTER TABLE auth_users + ADD COLUMN IF NOT EXISTS preferences JSONB NOT NULL DEFAULT '{}'; diff --git a/testgen/ui/components/frontend/js/pages/feedback_widget.js b/testgen/ui/components/frontend/js/pages/feedback_widget.js new file mode 100644 index 00000000..9f804b7b --- /dev/null +++ b/testgen/ui/components/frontend/js/pages/feedback_widget.js @@ -0,0 +1,225 @@ +import van from '/app/static/js/van.min.js'; +import { createEmitter, isEqual, loadStylesheet } from '/app/static/js/utils.js'; +import { Button } from '/app/static/js/components/button.js'; +import { Icon } from '/app/static/js/components/icon.js'; +import { Input } from '/app/static/js/components/input.js'; +import { Textarea } from '/app/static/js/components/textarea.js'; +const { div, span } = van.tags; + +const RATINGS = [ + { value: 1, emoji: '\u{1F620}', label: 'Frustrated' }, // 😠 + { value: 2, emoji: '\u{1F615}', label: 'Dissatisfied' }, // 😕 + { value: 3, emoji: '\u{1F610}', label: 'Neutral' }, // 😐 + { value: 4, emoji: '\u{1F642}', label: 'Satisfied' }, // 🙂 + { value: 5, emoji: '\u{1F929}', label: 'Love it!' }, // 🤩 +]; + +const FeedbackWidget = (props) => { + loadStylesheet('feedback-widget', stylesheet); + + const selectedRating = van.state(0); + const comment = van.state(''); + const email = van.state(''); + const expanded = van.state(false); + const showSuccess = van.state(false); + const submitting = van.state(false); + + const handleClose = () => { + props.emit('FeedbackDismissed', {}); + }; + + const handleSubmit = () => { + if (selectedRating.val === 0 || submitting.val) return; + submitting.val = true; + props.emit('FeedbackSubmitted', { + payload: { + rating: selectedRating.val, + comment: comment.val, + email: email.val, + }, + }); + showSuccess.val = true; + setTimeout(() => { + submitting.val = false; + props.emit('FeedbackDismissed', {}); + }, 2000); + }; + + return div( + { class: 'feedback-widget' }, + + () => !showSuccess.val + ? div( + { class: 'flex-column' }, + div( + { class: 'flex-row fx-justify-space-between p-4 pb-0' }, + div( + { class: 'flex-column fx-gap-1' }, + div({ class: 'text-bold' }, "How's your experience?"), + div({ class: 'text-caption' }, 'Your feedback helps us improve TestGen'), + ), + Button({ type: 'icon', color: 'basic', icon: 'close', onclick: handleClose }), + ), + div( + { class: 'flex-row fx-justify-space-between p-4' }, + ...RATINGS.map(rating => + div( + { + class: () => `rating-option ${selectedRating.val === rating.value ? 'selected' : ''}`, + onclick: () => { selectedRating.val = rating.value; }, + }, + span({ class: 'rating-emoji' }, rating.emoji), + span({ class: 'text-caption' }, rating.label), + ) + ), + ), + div( + { class: 'p-4 pt-0 flex-column fx-gap-3' }, + div( + { class: 'expander-row flex-row fx-justify-space-between clickable', onclick: () => { expanded.val = !expanded.val; } }, + span({ class: 'text-caption' }, 'Add a comment (optional)'), + Icon({ size: 18, classes: 'text-secondary' }, () => expanded.val ? 'keyboard_arrow_up' : 'keyboard_arrow_down'), + ), + div( + { class: 'flex-column fx-gap-3', style: () => expanded.val ? '' : 'display:none' }, + Textarea({ + label: 'Comment', + placeholder: "What's on your mind?", + value: comment, + onChange: (v) => { comment.val = v; }, + height: 64, + }), + Input({ + label: 'Email (optional)', + placeholder: 'you@company.com', + type: 'email', + value: email, + onChange: (v) => { email.val = v; }, + }), + ), + div( + { class: 'flex-row fx-justify-flex-end' }, + Button({ + type: 'flat', + color: 'primary', + label: 'Submit', + icon: 'send', + width: 'auto', + disabled: () => selectedRating.val === 0 || submitting.val, + onclick: handleSubmit, + }), + ), + ), + ) + : div( + { class: 'flex-column fx-align-flex-center p-5 feedback-success' }, + Icon({ size: 48, classes: 'text-green mb-3' }, 'check_circle'), + div({ class: 'text-bold mb-1' }, 'Thanks for your feedback!'), + div({ class: 'text-caption' }, 'We appreciate you taking the time.'), + ), + ); +}; + +const stylesheet = new CSSStyleSheet(); +stylesheet.replace(` +.feedback-widget { + position: fixed; + bottom: 24px; + right: 24px; + width: 340px; + font-family: 'Roboto', 'Helvetica Neue', sans-serif; + font-size: 14px; + color: var(--primary-text-color); + background: var(--portal-background); + border: 1px solid var(--border-color); + border-radius: 12px; + box-shadow: var(--portal-box-shadow); + overflow: hidden; + transition: opacity .25s, transform .25s; + transform-origin: bottom right; + z-index: 9999; +} + +.feedback-widget.hidden { + opacity: 0; + transform: scale(.95) translateY(8px); + pointer-events: none; +} + +.rating-option { + flex: 1; + display: flex; + flex-direction: column; + align-items: center; + gap: 4px; + padding: 8px 4px; + border-radius: 8px; + cursor: pointer; + transition: .2s; + border: 2px solid transparent; +} + +.rating-option:hover { + background: var(--select-hover-background); +} + +.rating-option.selected { + background: var(--select-hover-background); + border-color: var(--primary-color); +} + +.rating-emoji { + font-size: 28px; + line-height: 1; + filter: saturate(.8); + transition: .15s; +} + +.rating-option:hover .rating-emoji, +.rating-option.selected .rating-emoji { + transform: scale(1.15); + filter: saturate(1); +} + +.rating-option.selected .text-caption { + color: var(--primary-color); + font-weight: 500; +} + +.expander-row { + padding: 4px; + border-radius: 6px; +} + +.expander-row:hover { + background: var(--select-hover-background); +} + +.feedback-success { + text-align: center; + min-height: 160px; +} +`); + +export default (component) => { + const { data, setTriggerValue, parentElement } = component; + + let componentState = parentElement.state; + if (componentState === undefined) { + componentState = {}; + for (const [key, value] of Object.entries(data)) { + componentState[key] = van.state(value); + } + parentElement.state = componentState; + componentState.emit = createEmitter(setTriggerValue); + van.add(parentElement, FeedbackWidget(componentState)); + } else { + for (const [key, value] of Object.entries(data)) { + if (!isEqual(componentState[key].val, value)) { + componentState[key].val = value; + } + } + } + + return () => { parentElement.state = null; }; +}; diff --git a/testgen/ui/components/widgets/__init__.py b/testgen/ui/components/widgets/__init__.py index 6b3f23a9..6fdf28ce 100644 --- a/testgen/ui/components/widgets/__init__.py +++ b/testgen/ui/components/widgets/__init__.py @@ -146,3 +146,10 @@ js="pages/sidebar.js", isolate_styles=False, )) + +feedback_widget = component_v2_wrapped(components_v2.component( + name="dataops-testgen.feedback_widget", + js="pages/feedback_widget.js", + isolate_styles=False, +)) + diff --git a/testgen/ui/components/widgets/page.py b/testgen/ui/components/widgets/page.py index 737f4a55..9775306f 100644 --- a/testgen/ui/components/widgets/page.py +++ b/testgen/ui/components/widgets/page.py @@ -9,6 +9,7 @@ import testgen.common.logs as logs from testgen import settings from testgen.common import version_service +from testgen.common.mixpanel_service import MixpanelService from testgen.ui.services.rerun_service import safe_rerun from testgen.ui.session import session @@ -44,6 +45,9 @@ def page_header( st.html('
') + # Feedback widget (bottom-right) + render_feedback_widget() + # Render app logs dialog widget (outside the header container) logs_data = st.session_state.get(APP_LOGS_DIALOG_KEY) if logs_data: @@ -114,6 +118,10 @@ def open_app_logs(): close_help() st.session_state[APP_LOGS_DIALOG_KEY] = _read_log_data() + def open_feedback(): + close_help() + session.show_feedback_popup = True + with help_container.container(): flex_row_end() with st.popover("Help"): @@ -127,9 +135,11 @@ def open_app_logs(): "version": version.__dict__, "permissions": { "can_edit": session.auth.user_has_permission("edit"), + "is_logged_in": session.auth.is_logged_in, }, }, on_AppLogsClicked_change=lambda _: open_app_logs(), + on_FeedbackClicked_change=lambda _: open_feedback(), on_ExternalLinkClicked_change=lambda _: close_help(rerun=True), ) @@ -175,3 +185,36 @@ def _apply_html(html: str, container: DeltaGenerator | None = None): container.html(html) else: st.html(html) + + +def render_feedback_widget(): + """Render the feedback popup widget in the bottom-right corner. + + Visibility is driven by session.show_feedback_popup: + - set by router on session start (30-day eligibility gate) + - set when the user manually clicks "Give Feedback" + + Feedback submissions are sent to MixPanel. + """ + if not bool(session.show_feedback_popup): + return + + def on_dismissed(_): + session.show_feedback_popup = False + + def on_submitted(payload): + if payload: + MixpanelService().send_feedback( + rating=int(payload.get("rating", 0)), + comment=payload.get("comment") or None, + email=payload.get("email") or None, + ) + + from testgen.ui.components.widgets import feedback_widget + feedback_widget( + key="feedback_widget", + data={}, + on_FeedbackDismissed_change=on_dismissed, + on_FeedbackSubmitted_change=on_submitted, + ) + diff --git a/testgen/ui/navigation/router.py b/testgen/ui/navigation/router.py index c53c8759..c86636cd 100644 --- a/testgen/ui/navigation/router.py +++ b/testgen/ui/navigation/router.py @@ -2,11 +2,14 @@ import logging import time +from datetime import UTC, datetime import streamlit as st import testgen.ui.navigation.page +from testgen import settings from testgen.common.mixpanel_service import MixpanelService +from testgen.common.models.user import FEEDBACK_POPUP_INTERVAL, PreferenceKey from testgen.ui.session import session from testgen.utils.singleton import Singleton @@ -31,6 +34,25 @@ def _init_session(self, url: str): source = st.query_params.pop("source", None) MixpanelService().send_event(f"nav-{url}", page_load=True, source=source) + def _evaluate_feedback_popup(self) -> None: + session.show_feedback_popup = False + try: + if settings.DISABLE_FEEDBACK_POPUP or not (user := session.auth.user): + return + + if (last_popup_str := user.get_preference(PreferenceKey.LAST_FEEDBACK_POPUP)): + try: + last_popup_dt = datetime.fromisoformat(last_popup_str) + if datetime.now(UTC) - last_popup_dt < FEEDBACK_POPUP_INTERVAL: + return + except (ValueError, TypeError): + pass # Corrupted value — treat as no prior popup + + user.set_preference(PreferenceKey.LAST_FEEDBACK_POPUP, datetime.now(UTC).isoformat()) + session.show_feedback_popup = True + except Exception: + LOG.exception("Error evaluating feedback popup eligibility") + def run(self) -> None: streamlit_pages = [route.streamlit_page for route in self._routes.values()] @@ -63,6 +85,9 @@ def run(self) -> None: st.query_params.from_dict(session.page_args_pending_router) session.page_args_pending_router = None + if session.show_feedback_popup is None and session.auth.is_logged_in: + self._evaluate_feedback_popup() + session.current_page = current_page.url_path current_page.run() else: diff --git a/testgen/ui/session.py b/testgen/ui/session.py index 9f50ed33..5b8389f5 100644 --- a/testgen/ui/session.py +++ b/testgen/ui/session.py @@ -35,6 +35,8 @@ class TestgenSession(Singleton): add_project: bool version: Version | None + show_feedback_popup: bool | None + testgen_event_id: ClassVar[dict[str, str]] = {} sidebar_event_id: str | None link_event_id: str | None diff --git a/testgen/ui/static/css/style.css b/testgen/ui/static/css/style.css index 01dee345..66b56aa5 100644 --- a/testgen/ui/static/css/style.css +++ b/testgen/ui/static/css/style.css @@ -422,6 +422,10 @@ Use as testgen.text("text", "extra_styles") */ div[data-testid="stPopoverBody"]:has(i.tg-header--help-wrapper) { padding: 0; } + +.st-key-feedback_widget { + z-index: 9999; +} /* */ /* Summary bar component */ diff --git a/testgen/ui/static/js/components/help_menu.js b/testgen/ui/static/js/components/help_menu.js index 2a2fd9cb..33d04f44 100644 --- a/testgen/ui/static/js/components/help_menu.js +++ b/testgen/ui/static/js/components/help_menu.js @@ -8,6 +8,7 @@ * @typedef Permissions * @type {object} * @property {boolean} can_edit + * @property {boolean} is_logged_in * * @typedef Properties * @type {object} @@ -71,6 +72,13 @@ const HelpMenu = (/** @type Properties */ props) => { ) : null, span({ class: 'help-divider' }), + getValue(props.permissions)?.is_logged_in + ? div( + { class: 'help-item', onclick: () => emit('FeedbackClicked') }, + Icon({ classes: 'help-item-icon' }, 'rate_review'), + 'Give Feedback', + ) + : null, HelpLink(slackUrl, 'Slack Community', 'group'), getValue(props.support_email) ? HelpLink( From 8e5d3aea38ffd3dce786b7f00a52f7108fd4691c Mon Sep 17 00:00:00 2001 From: Luis Date: Tue, 19 May 2026 19:42:47 -0400 Subject: [PATCH 45/58] feat(mcp): add CRUD mcp tools for notifications - list_notifications - get_notification - create_notification - update_notification - delete_notification --- deploy/build_mcp_docs.py | 2 + .../common/models/notification_settings.py | 167 +- testgen/common/models/scores.py | 13 + testgen/mcp/server.py | 12 + testgen/mcp/tools/common.py | 115 + testgen/mcp/tools/notifications.py | 995 +++++++ .../models/test_notification_settings.py | 139 + .../common/models/test_score_definition.py | 39 + tests/unit/mcp/test_tools_notifications.py | 2318 +++++++++++++++++ 9 files changed, 3784 insertions(+), 16 deletions(-) create mode 100644 testgen/mcp/tools/notifications.py create mode 100644 tests/unit/common/models/test_notification_settings.py create mode 100644 tests/unit/mcp/test_tools_notifications.py diff --git a/deploy/build_mcp_docs.py b/deploy/build_mcp_docs.py index 2040820d..0fd01c19 100644 --- a/deploy/build_mcp_docs.py +++ b/deploy/build_mcp_docs.py @@ -31,6 +31,8 @@ DocGroup.INVESTIGATE, DocGroup.BROWSE_PROFILING, DocGroup.TRIGGER, + DocGroup.SCORING, + DocGroup.MANAGE, ] _FALLBACK_GROUP = "Other tools" diff --git a/testgen/common/models/notification_settings.py b/testgen/common/models/notification_settings.py index e15349f4..90063422 100644 --- a/testgen/common/models/notification_settings.py +++ b/testgen/common/models/notification_settings.py @@ -1,6 +1,7 @@ import enum import re from collections.abc import Iterable +from dataclasses import dataclass from decimal import Decimal from typing import ClassVar, Generic, Self, TypeVar from uuid import UUID, uuid4 @@ -8,6 +9,7 @@ from sqlalchemy import Boolean, Column, Enum, ForeignKey, String, and_, or_, select from sqlalchemy.dialects import postgresql from sqlalchemy.sql import Select +from sqlalchemy.sql.elements import ColumnElement from testgen.common.models import get_current_session from testgen.common.models.custom_types import JSON_TYPE @@ -22,6 +24,17 @@ TriggerT = TypeVar("TriggerT", bound=Enum) +_EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$") + + +def is_valid_email(value: str) -> bool: + """Return whether ``value`` is a well-formed email address. + + Single source of truth for recipient validation, shared by the model's + ``validate()`` and the MCP layer's batch recipient check. + """ + return bool(_EMAIL_REGEX.match(value)) + class TestRunNotificationTrigger(enum.Enum): always = "always" @@ -51,6 +64,27 @@ class NotificationSettingsValidationError(Exception): pass +@dataclass +class NotificationSummary: + """Row shape for paginated ``NotificationSettings.list_for_*`` queries. + + Field order matches the SELECT projection in the ``list_for_*`` methods. + ``settings`` keeps the raw JSONB blob so event-specific values (``trigger``, + ``total_threshold``, ``cde_threshold``, ``table_name``) can be read by the + consumer's format helpers without forking the dataclass per event type. + """ + + id: UUID + project_code: str + event: NotificationEvent + enabled: bool + recipients: list[str] + test_suite_id: UUID | None + table_group_id: UUID | None + score_definition_id: UUID | None + settings: dict + + class NotificationSettings(Entity): __tablename__ = "notification_settings" @@ -87,6 +121,17 @@ class NotificationSettings(Entity): "polymorphic_identity": "base", } + @classmethod + def _scope_subquery(cls, entity, rel_col, id_value) -> ColumnElement[bool]: + """Where-clause: rows scoped to ``entity.id == id_value`` plus project-wide rows + (``rel_col IS NULL``) for the same project. Used by both the streaming + ``select()`` and the paginated ``list_for_*`` methods. + """ + return and_( + cls.project_code.in_(select(entity.project_code).where(entity.id == id_value)), + or_(rel_col == id_value, rel_col.is_(None)), + ) + @classmethod def _base_select_query( cls, @@ -94,9 +139,9 @@ def _base_select_query( enabled: bool | SENTINEL_TYPE = SENTINEL, event: NotificationEvent | SENTINEL_TYPE = SENTINEL, project_code: str | SENTINEL_TYPE = SENTINEL, - test_suite_id: UUID | None | SENTINEL_TYPE = SENTINEL, - table_group_id: UUID | None | SENTINEL_TYPE = SENTINEL, - score_definition_id: UUID | None | SENTINEL_TYPE = SENTINEL, + test_suite_id: UUID | SENTINEL_TYPE | None = SENTINEL, + table_group_id: UUID | SENTINEL_TYPE | None = SENTINEL, + score_definition_id: UUID | SENTINEL_TYPE | None = SENTINEL, ) -> Select: fk_count = len([None for fk in (test_suite_id, table_group_id, score_definition_id) if fk is not SENTINEL]) if fk_count > 1: @@ -112,18 +157,12 @@ def _base_select_query( if project_code is not SENTINEL: query = query.where(cls.project_code == project_code) - def _subquery_clauses(entity, rel_col, id_value): - return and_( - cls.project_code.in_(select(entity.project_code).where(entity.id == id_value)), - or_(rel_col == id_value, rel_col.is_(None)), - ) - if test_suite_id is not SENTINEL: - query = query.where(_subquery_clauses(TestSuite, cls.test_suite_id, test_suite_id)) + query = query.where(cls._scope_subquery(TestSuite, cls.test_suite_id, test_suite_id)) elif table_group_id is not SENTINEL: - query = query.where(_subquery_clauses(TableGroup, cls.table_group_id, table_group_id)) + query = query.where(cls._scope_subquery(TableGroup, cls.table_group_id, table_group_id)) elif score_definition_id is not SENTINEL: - query = query.where(_subquery_clauses(ScoreDefinition, cls.score_definition_id, score_definition_id)) + query = query.where(cls._scope_subquery(ScoreDefinition, cls.score_definition_id, score_definition_id)) return query @@ -134,9 +173,9 @@ def select( enabled: bool | SENTINEL_TYPE = SENTINEL, event: NotificationEvent | SENTINEL_TYPE = SENTINEL, project_code: str | SENTINEL_TYPE = SENTINEL, - test_suite_id: UUID | None | SENTINEL_TYPE = SENTINEL, - table_group_id: UUID | None | SENTINEL_TYPE = SENTINEL, - score_definition_id: UUID | None | SENTINEL_TYPE = SENTINEL, + test_suite_id: UUID | SENTINEL_TYPE | None = SENTINEL, + table_group_id: UUID | SENTINEL_TYPE | None = SENTINEL, + score_definition_id: UUID | SENTINEL_TYPE | None = SENTINEL, ) -> Iterable[Self]: query = cls._base_select_query( enabled=enabled, @@ -150,6 +189,102 @@ def select( ) return get_current_session().scalars(query) + @classmethod + def _list_query(cls, scope_clause) -> Select: + """Projection + ORDER BY shared by every ``list_for_*`` classmethod. + + ``scope_clause`` is the WHERE expression that narrows to a project or a parent + entity (and its project-wide siblings). Caller-supplied filters arrive as + ``*clauses`` in each ``list_for_*`` wrapper and are appended here. + """ + return ( + select( + cls.id.label("id"), + cls.project_code.label("project_code"), + cls.event.label("event"), + cls.enabled.label("enabled"), + cls.recipients.label("recipients"), + cls.test_suite_id.label("test_suite_id"), + cls.table_group_id.label("table_group_id"), + cls.score_definition_id.label("score_definition_id"), + cls.settings.label("settings"), + ) + .where(scope_clause) + .order_by( + cls.project_code, cls.event, cls.test_suite_id, + cls.table_group_id, cls.score_definition_id, cls.id, + ) + ) + + @classmethod + def list_for_projects( + cls, + project_codes: Iterable[str], + *clauses, + page: int = 1, + limit: int = 50, + ) -> tuple[list[NotificationSummary], int]: + """Paginated notifications across one or more projects.""" + query = cls._list_query(cls.project_code.in_(list(project_codes))) + if clauses: + query = query.where(*clauses) + return cls._paginate(query, page=page, limit=limit, data_class=NotificationSummary) + + @classmethod + def list_for_test_suite( + cls, + test_suite_id: UUID, + *clauses, + page: int = 1, + limit: int = 50, + ) -> tuple[list[NotificationSummary], int]: + """Paginated notifications whose ``test_suite_id`` exactly matches ``test_suite_id``. + + Use ``list_for_projects`` to also surface project-wide notifications (rows + with ``test_suite_id IS NULL``) — they're a different display concern from + narrowing to a specific suite. + """ + query = cls._list_query(cls.test_suite_id == test_suite_id) + if clauses: + query = query.where(*clauses) + return cls._paginate(query, page=page, limit=limit, data_class=NotificationSummary) + + @classmethod + def list_for_table_group( + cls, + table_group_id: UUID, + *clauses, + page: int = 1, + limit: int = 50, + ) -> tuple[list[NotificationSummary], int]: + """Paginated notifications whose ``table_group_id`` exactly matches ``table_group_id``. + + Use ``list_for_projects`` to also surface project-wide notifications (rows + with ``table_group_id IS NULL``). + """ + query = cls._list_query(cls.table_group_id == table_group_id) + if clauses: + query = query.where(*clauses) + return cls._paginate(query, page=page, limit=limit, data_class=NotificationSummary) + + @classmethod + def list_for_score_definition( + cls, + score_definition_id: UUID, + *clauses, + page: int = 1, + limit: int = 50, + ) -> tuple[list[NotificationSummary], int]: + """Paginated notifications whose ``score_definition_id`` exactly matches ``score_definition_id``. + + Use ``list_for_projects`` to also surface project-wide notifications (rows + with ``score_definition_id IS NULL``). + """ + query = cls._list_query(cls.score_definition_id == score_definition_id) + if clauses: + query = query.where(*clauses) + return cls._paginate(query, page=page, limit=limit, data_class=NotificationSummary) + def _validate_settings(self): pass @@ -157,7 +292,7 @@ def validate(self): if len(self.recipients) < 1: raise NotificationSettingsValidationError("At least one recipient must be defined.") for addr in self.recipients: - if not re.match(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$", addr): + if not is_valid_email(addr): raise NotificationSettingsValidationError(f"Invalid email address: {addr}.") self._validate_settings() diff --git a/testgen/common/models/scores.py b/testgen/common/models/scores.py index e2df3df7..4dfb4c10 100644 --- a/testgen/common/models/scores.py +++ b/testgen/common/models/scores.py @@ -157,6 +157,19 @@ def get(cls, id_: str) -> Self | None: definition = db_session.scalars(query).first() return definition + @classmethod + def names_by_id(cls, ids: Iterable[UUID]) -> dict[UUID, str]: + """Return ``{id: name}`` for the given scorecard IDs in a single query. + + IDs with no matching scorecard are omitted. Empty input yields ``{}`` + without touching the database. + """ + ids = list(ids) + if not ids: + return {} + query = select(cls.id, cls.name).where(cls.id.in_(ids)) + return {row.id: row.name for row in get_current_session().execute(query).all()} + @classmethod def list_with_table_group_targets( cls, diff --git a/testgen/mcp/server.py b/testgen/mcp/server.py index a29e84b6..5e91b2f1 100644 --- a/testgen/mcp/server.py +++ b/testgen/mcp/server.py @@ -151,6 +151,13 @@ def build_mcp_server( search_hygiene_issues, update_hygiene_issue, ) + from testgen.mcp.tools.notifications import ( + create_notification, + delete_notification, + get_notification, + list_notifications, + update_notification, + ) from testgen.mcp.tools.profile_history import ( compare_profiling_runs, get_profiling_trends, @@ -298,6 +305,11 @@ def safe_prompt(fn): safe_tool(create_scorecard) safe_tool(update_scorecard) safe_tool(delete_scorecard) + safe_tool(list_notifications) + safe_tool(get_notification) + safe_tool(create_notification) + safe_tool(update_notification) + safe_tool(delete_notification) # Resources safe_resource("testgen://test-types", test_types_resource) diff --git a/testgen/mcp/tools/common.py b/testgen/mcp/tools/common.py index ea472419..e08ca861 100644 --- a/testgen/mcp/tools/common.py +++ b/testgen/mcp/tools/common.py @@ -15,6 +15,13 @@ SuggestedDataType, ) from testgen.common.models.hygiene_issue import HygieneIssueType +from testgen.common.models.notification_settings import ( + MonitorNotificationTrigger, + NotificationEvent, + NotificationSettings, + ProfilingRunNotificationTrigger, + TestRunNotificationTrigger, +) from testgen.common.models.profiling_run import ProfilingRun from testgen.common.models.scheduler import SCHEDULABLE_JOB_KEYS, JobSchedule from testgen.common.models.scores import ScoreCategory, ScoreDefinition @@ -49,6 +56,7 @@ class DocGroup(StrEnum): BROWSE_PROFILING = "Browse profiling results" TRIGGER = "Trigger profiling, tests, and test generation" SCORING = "Track data quality scores" + MANAGE = "Manage TestGen configuration" def parse_uuid(value: str, label: str = "ID") -> UUID: @@ -609,3 +617,110 @@ def resolve_schedule(schedule_id: str) -> JobSchedule: if sched is None: raise MCPResourceNotAccessible("Schedule", schedule_id) return sched + + +def resolve_notification(notification_id: str) -> NotificationSettings: + """Resolve a notification ID, collapsing missing-or-inaccessible into one error path. + + Returns the polymorphic ``NotificationSettings`` subclass (TestRun / ProfilingRun / + ScoreDrop / Monitor) so callers can read event-specific typed properties. + """ + notif_uuid = parse_uuid(notification_id, "notification_id") + perms = get_project_permissions() + notif = NotificationSettings.get( + notif_uuid, + NotificationSettings.project_code.in_(perms.allowed_codes), + ) + if notif is None: + raise MCPResourceNotAccessible("Notification", notification_id) + return notif + + +# Notification event-type labels. + +class NotificationEventLabel(StrEnum): + """User-facing values for notification event types.""" + + TEST_RUN = "Test Run" + PROFILING_RUN = "Profiling Run" + SCORE_DROP = "Score Drop" + MONITOR_RUN = "Monitor Alert" + + +NOTIFICATION_EVENT_LABEL_TO_INTERNAL: dict[NotificationEventLabel, NotificationEvent] = { + NotificationEventLabel.TEST_RUN: NotificationEvent.test_run, + NotificationEventLabel.PROFILING_RUN: NotificationEvent.profiling_run, + NotificationEventLabel.SCORE_DROP: NotificationEvent.score_drop, + NotificationEventLabel.MONITOR_RUN: NotificationEvent.monitor_run, +} + +_NOTIFICATION_EVENT_INTERNAL_TO_LABEL: dict[NotificationEvent, NotificationEventLabel] = { + v: k for k, v in NOTIFICATION_EVENT_LABEL_TO_INTERNAL.items() +} + + +def format_notification_event(event: NotificationEvent | str) -> str: + """Map a stored notification event to its user-facing label.""" + return _NOTIFICATION_EVENT_INTERNAL_TO_LABEL[NotificationEvent(event)].value + + +# Notification trigger labels — one StrEnum per event type. Same wording the end user sees in the UI: +# ``ui/views/test_runs.py:249-254``, ``ui/views/profiling_runs.py:265-268``, +# ``ui/views/monitors_dashboard.py:323-326``. + +class TestRunTriggerLabel(StrEnum): + ALWAYS = "Always" + ON_FAILURES = "On test failures" + ON_WARNINGS = "On test failures and warnings" + ON_CHANGES = "On new test failures and warnings" + + +TEST_RUN_TRIGGER_LABEL_TO_INTERNAL: dict[TestRunTriggerLabel, TestRunNotificationTrigger] = { + TestRunTriggerLabel.ALWAYS: TestRunNotificationTrigger.always, + TestRunTriggerLabel.ON_FAILURES: TestRunNotificationTrigger.on_failures, + TestRunTriggerLabel.ON_WARNINGS: TestRunNotificationTrigger.on_warnings, + TestRunTriggerLabel.ON_CHANGES: TestRunNotificationTrigger.on_changes, +} + + +class ProfilingRunTriggerLabel(StrEnum): + ALWAYS = "Always" + ON_CHANGES = "On new hygiene issues" + + +PROFILING_RUN_TRIGGER_LABEL_TO_INTERNAL: dict[ProfilingRunTriggerLabel, ProfilingRunNotificationTrigger] = { + ProfilingRunTriggerLabel.ALWAYS: ProfilingRunNotificationTrigger.always, + ProfilingRunTriggerLabel.ON_CHANGES: ProfilingRunNotificationTrigger.on_changes, +} + + +class MonitorTriggerLabel(StrEnum): + ON_ANOMALIES = "On anomalies" + + +MONITOR_TRIGGER_LABEL_TO_INTERNAL: dict[MonitorTriggerLabel, MonitorNotificationTrigger] = { + MonitorTriggerLabel.ON_ANOMALIES: MonitorNotificationTrigger.on_anomalies, +} + +_TEST_RUN_TRIGGER_INTERNAL_TO_LABEL = {v: k for k, v in TEST_RUN_TRIGGER_LABEL_TO_INTERNAL.items()} +_PROFILING_RUN_TRIGGER_INTERNAL_TO_LABEL = {v: k for k, v in PROFILING_RUN_TRIGGER_LABEL_TO_INTERNAL.items()} +_MONITOR_TRIGGER_INTERNAL_TO_LABEL = {v: k for k, v in MONITOR_TRIGGER_LABEL_TO_INTERNAL.items()} + + +def format_notification_trigger(event: NotificationEvent | str, settings: dict | None) -> str | None: + """Map a notification's stored trigger value to its user-facing label. + + Returns ``None`` for ``score_drop`` (no trigger — thresholds drive it) or when + ``settings`` carries no ``trigger`` key. + """ + raw = settings.get("trigger") if settings else None + if raw is None: + return None + event_enum = NotificationEvent(event) + if event_enum is NotificationEvent.test_run: + return _TEST_RUN_TRIGGER_INTERNAL_TO_LABEL[TestRunNotificationTrigger(raw)].value + if event_enum is NotificationEvent.profiling_run: + return _PROFILING_RUN_TRIGGER_INTERNAL_TO_LABEL[ProfilingRunNotificationTrigger(raw)].value + if event_enum is NotificationEvent.monitor_run: + return _MONITOR_TRIGGER_INTERNAL_TO_LABEL[MonitorNotificationTrigger(raw)].value + return None diff --git a/testgen/mcp/tools/notifications.py b/testgen/mcp/tools/notifications.py new file mode 100644 index 00000000..5b798ca6 --- /dev/null +++ b/testgen/mcp/tools/notifications.py @@ -0,0 +1,995 @@ +from dataclasses import dataclass +from decimal import Decimal +from enum import StrEnum +from uuid import UUID + +from testgen.common.models import with_database_session +from testgen.common.models.notification_settings import ( + MonitorNotificationTrigger, + NotificationEvent, + NotificationSettings, + NotificationSummary, + ProfilingRunNotificationSettings, + ProfilingRunNotificationTrigger, + ScoreDropNotificationSettings, + TestRunNotificationSettings, + TestRunNotificationTrigger, + is_valid_email, +) +from testgen.common.models.scores import ScoreDefinition +from testgen.common.models.table_group import TableGroup +from testgen.common.models.test_suite import TestSuite +from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError +from testgen.mcp.permissions import get_project_permissions, mcp_permission +from testgen.mcp.tools.common import ( + MONITOR_TRIGGER_LABEL_TO_INTERNAL, + NOTIFICATION_EVENT_LABEL_TO_INTERNAL, + PROFILING_RUN_TRIGGER_LABEL_TO_INTERNAL, + TEST_RUN_TRIGGER_LABEL_TO_INTERNAL, + DocGroup, + MonitorTriggerLabel, + NotificationEventLabel, + ProfilingRunTriggerLabel, + TestRunTriggerLabel, + format_notification_event, + format_notification_trigger, + format_page_footer, + format_page_info, + resolve_notification, + resolve_scorecard, + resolve_table_group, + resolve_test_suite, + validate_limit, + validate_page, +) +from testgen.mcp.tools.markdown import MdDoc + +_DOC_GROUP = DocGroup.MANAGE + +# ``Monitor Alert`` is intentionally excluded from creation: a monitor notification is +# bound to its (internal, user-invisible) monitor test suite at monitor-setup time, so it +# can't be created standalone here. Existing monitor notifications are still managed via +# get/update/delete/list_notifications. +_CREATE_SUPPORTED_EVENTS: tuple[NotificationEvent, ...] = ( + NotificationEvent.test_run, + NotificationEvent.profiling_run, + NotificationEvent.score_drop, +) + + +@with_database_session +@mcp_permission("view") +def list_notifications( + project_code: str | None = None, + test_suite_id: str | None = None, + table_group_id: str | None = None, + scorecard_id: str | None = None, + limit: int = 50, + page: int = 1, +) -> str: + """List notifications configured across projects, or scoped to a parent entity. + + With no scope argument, returns notifications across every project the caller can view. + Provide one of ``project_code`` / ``test_suite_id`` / ``table_group_id`` / ``scorecard_id`` + to narrow the listing. Parent-entity scopes filter strictly on that entity — to also + see project-wide notifications (those not bound to a specific suite, table group, or + scorecard), use ``project_code``. + + Args: + project_code: Scope to a specific project. + test_suite_id: UUID of a test suite, e.g. from ``list_test_suites``. Returns only + notifications bound to this suite. + table_group_id: UUID of a table group, e.g. from ``get_data_inventory``. Returns + only notifications bound to this table group. + scorecard_id: UUID of a scorecard, e.g. from ``list_scorecards``. Returns only + notifications bound to this scorecard. + limit: Maximum number of notifications per page (default 50, max 200). + page: Page number, starting from 1 (default 1). + """ + validate_page(page) + validate_limit(limit, 200) + + scope_args = { + "project_code": project_code, + "test_suite_id": test_suite_id, + "table_group_id": table_group_id, + "scorecard_id": scorecard_id, + } + provided = [name for name, value in scope_args.items() if value] + if len(provided) > 1: + raise MCPUserError( + "Pass at most one of `project_code`, `test_suite_id`, `table_group_id`, `scorecard_id`." + ) + + perms = get_project_permissions() + scope_label: str | None = None + + if test_suite_id: + suite = resolve_test_suite(test_suite_id) + rows, total = NotificationSettings.list_for_test_suite(suite.id, page=page, limit=limit) + scope_label = f"Test Suite `{suite.test_suite}`" + elif table_group_id: + tg = resolve_table_group(table_group_id) + rows, total = NotificationSettings.list_for_table_group(tg.id, page=page, limit=limit) + scope_label = f"Table Group `{tg.table_groups_name}`" + elif scorecard_id: + scorecard = resolve_scorecard(scorecard_id) + rows, total = NotificationSettings.list_for_score_definition(scorecard.id, page=page, limit=limit) + scope_label = f"Scorecard `{scorecard.name}`" + elif project_code: + perms.verify_access(project_code, not_found=MCPResourceNotAccessible("Project", project_code)) + rows, total = NotificationSettings.list_for_projects([project_code], page=page, limit=limit) + scope_label = f"Project `{project_code}`" + else: + rows, total = NotificationSettings.list_for_projects(perms.allowed_codes, page=page, limit=limit) + + return _render(rows, total, page=page, limit=limit, scope_label=scope_label) + + +@with_database_session +@mcp_permission("view") +def get_notification(notification_id: str) -> str: + """Get full details of an email notification: event type, trigger or thresholds, + scope (project, test suite, table group, or scorecard), and recipients. + + Works on any notification, including ``Monitor Alert`` notifications — those are + created through monitor setup rather than this tool, but can be viewed here. + + Args: + notification_id: UUID of the notification, e.g. from ``list_notifications``. + """ + notif = resolve_notification(notification_id) + return _render_one(notif) + + +@with_database_session +@mcp_permission("edit") +def create_notification( + event_type: str, + recipients: list[str], + test_suite_id: str | None = None, + table_group_id: str | None = None, + scorecard_id: str | None = None, + trigger_on: str | None = None, + total_threshold: float | None = None, + cde_threshold: float | None = None, +) -> str: + """Create an email notification for a test-run, profiling-run, or score-drop event. + + Every invalid input is surfaced in a single error so the call can be corrected + in one round-trip — no partial save occurs. + + Args: + event_type: The event that triggers the notification. One of + ``Test Run``, ``Profiling Run``, ``Score Drop``. ``Monitor Alert`` + notifications are configured in the TestGen UI and cannot be created + here; ``update_notification`` can still modify them once they exist. + recipients: One or more well-formed email addresses to notify. + test_suite_id: UUID of the test suite, e.g. from ``list_test_suites``. + Required when ``event_type`` is ``Test Run``; rejected otherwise. + table_group_id: UUID of the table group, e.g. from ``get_data_inventory``. + Required when ``event_type`` is ``Profiling Run``; rejected otherwise. + scorecard_id: UUID of the scorecard, e.g. from ``list_scorecards``. + Required when ``event_type`` is ``Score Drop``; rejected otherwise. + trigger_on: When to fire the notification. Only used for ``Test Run`` + and ``Profiling Run``; rejected for ``Score Drop``. + For ``Test Run`` (default ``On test failures``): one of ``Always``, + ``On test failures``, ``On test failures and warnings``, + ``On new test failures and warnings``. + For ``Profiling Run`` (default ``On new hygiene issues``): one of + ``Always``, ``On new hygiene issues``. + total_threshold: Score-drop trigger for the total score (over 0, up to 100). + Only used for ``Score Drop``; at least one of ``total_threshold`` or + ``cde_threshold`` must be supplied. + cde_threshold: Score-drop trigger for the critical-data-element score + (over 0, up to 100). Only used for ``Score Drop``. + """ + event = _parse_event_type(event_type) + + if event is NotificationEvent.test_run: + _enforce_scope_shape( + event_type, + required=("test_suite_id", test_suite_id), + forbidden=(("table_group_id", table_group_id), ("scorecard_id", scorecard_id)), + ) + _reject_threshold_args(event_type, total_threshold, cde_threshold) + suite = resolve_test_suite(test_suite_id) + clean_recipients = _validate_recipients(recipients) + trigger = _parse_test_run_trigger(trigger_on) + notif = TestRunNotificationSettings.create( + project_code=suite.project_code, + test_suite_id=suite.id, + recipients=clean_recipients, + trigger=trigger, + ) + elif event is NotificationEvent.profiling_run: + _enforce_scope_shape( + event_type, + required=("table_group_id", table_group_id), + forbidden=(("test_suite_id", test_suite_id), ("scorecard_id", scorecard_id)), + ) + _reject_threshold_args(event_type, total_threshold, cde_threshold) + tg = resolve_table_group(table_group_id) + clean_recipients = _validate_recipients(recipients) + trigger = _parse_profiling_run_trigger(trigger_on) + notif = ProfilingRunNotificationSettings.create( + project_code=tg.project_code, + table_group_id=tg.id, + recipients=clean_recipients, + trigger=trigger, + ) + else: + # NotificationEvent.score_drop — _parse_event_type rejected anything else. + _enforce_scope_shape( + event_type, + required=("scorecard_id", scorecard_id), + forbidden=(("test_suite_id", test_suite_id), ("table_group_id", table_group_id)), + ) + if trigger_on is not None: + raise MCPUserError( + f"`trigger_on` is not supported for event type `{event_type}` — thresholds drive the event." + ) + scorecard = resolve_scorecard(scorecard_id) + _validate_score_thresholds(total_threshold, cde_threshold) + clean_recipients = _validate_recipients(recipients) + notif = ScoreDropNotificationSettings.create( + project_code=scorecard.project_code, + score_definition_id=scorecard.id, + recipients=clean_recipients, + total_score_threshold=total_threshold, + cde_score_threshold=cde_threshold, + ) + + return _render_created(notif) + + +# --- create_notification helpers --- + + +def _parse_event_type(value: str) -> NotificationEvent: + """Map the supplied display label to its ``NotificationEvent``. + + Rejects anything outside the create-supported subset (test_run / profiling_run / + score_drop) — including the otherwise-valid ``Monitor Alert`` event. Raises + ``MCPUserError`` listing every supported display label. + """ + label: NotificationEventLabel | None + try: + label = NotificationEventLabel(value) + except ValueError: + label = None + event = NOTIFICATION_EVENT_LABEL_TO_INTERNAL.get(label) if label is not None else None + if event not in _CREATE_SUPPORTED_EVENTS: + valid = ", ".join(f"`{format_notification_event(e)}`" for e in _CREATE_SUPPORTED_EVENTS) + raise MCPUserError(f"Invalid `event_type` `{value}`. Valid values: {valid}.") + return event + + +def _enforce_scope_shape( + event_type: str, + *, + required: tuple[str, str | None], + forbidden: tuple[tuple[str, str | None], ...], +) -> None: + """Reject missing-required or any forbidden scope args for the chosen event.""" + required_name, required_value = required + if not required_value: + raise MCPUserError(f"`{required_name}` is required for event type `{event_type}`.") + supplied_forbidden = [name for name, value in forbidden if value] + if supplied_forbidden: + joined = ", ".join(f"`{name}`" for name in supplied_forbidden) + raise MCPUserError(f"{joined} not supported for event type `{event_type}`. Use only `{required_name}`.") + + +def _reject_threshold_args( + event_type: str, + total_threshold: float | None, + cde_threshold: float | None, +) -> None: + """Reject ``total_threshold`` / ``cde_threshold`` on non-score events.""" + stray = [ + name + for name, value in ( + ("total_threshold", total_threshold), + ("cde_threshold", cde_threshold), + ) + if value is not None + ] + if stray: + joined = ", ".join(f"`{name}`" for name in stray) + raise MCPUserError( + f"{joined} not supported for event type `{event_type}`. " + "Only `Score Drop` notifications use score thresholds." + ) + + +def _validate_recipients(recipients: list[str]) -> list[str]: + """Return the recipients list after batch-validating every entry. + + Raises ``MCPUserError`` if the list is empty or contains any malformed address — + every bad address is named in the single error message so the caller can fix + them all in one round-trip. + """ + if not recipients: + raise MCPUserError("`recipients` must contain at least one email address.") + invalid = [addr for addr in recipients if not is_valid_email(addr)] + if invalid: + joined = ", ".join(f"`{addr}`" for addr in invalid) + raise MCPUserError(f"Invalid email addresses: {joined}.") + return list(recipients) + + +def _parse_test_run_trigger(value: str | None) -> TestRunNotificationTrigger: + if value is None: + return TestRunNotificationTrigger.on_failures + try: + label = TestRunTriggerLabel(value) + except ValueError as err: + valid = ", ".join(f"`{label.value}`" for label in TestRunTriggerLabel) + raise MCPUserError( + f"Invalid `trigger_on` `{value}` for event type `Test Run`. Valid values: {valid}." + ) from err + return TEST_RUN_TRIGGER_LABEL_TO_INTERNAL[label] + + +def _parse_profiling_run_trigger(value: str | None) -> ProfilingRunNotificationTrigger: + if value is None: + return ProfilingRunNotificationTrigger.on_changes + try: + label = ProfilingRunTriggerLabel(value) + except ValueError as err: + valid = ", ".join(f"`{label.value}`" for label in ProfilingRunTriggerLabel) + raise MCPUserError( + f"Invalid `trigger_on` `{value}` for event type `Profiling Run`. Valid values: {valid}." + ) from err + return PROFILING_RUN_TRIGGER_LABEL_TO_INTERNAL[label] + + +def _validate_score_thresholds( + total_threshold: float | None, + cde_threshold: float | None, +) -> None: + """Reject missing-or-out-of-range thresholds for a score-drop notification. + + Surfaces every range violation in a single error. + """ + if total_threshold is None and cde_threshold is None: + raise MCPUserError( + "At least one of `total_threshold` or `cde_threshold` must be set for event type `Score Drop`." + ) + _validate_threshold_range(total_threshold, cde_threshold) + + +def _validate_threshold_range( + total_threshold: float | None, + cde_threshold: float | None, +) -> None: + """Reject any out-of-range threshold value; surface every offender in one error. + + 0 is rejected: a score can never drop below 0, so a 0 threshold would never fire. + """ + range_errors = [] + for name, value in (("total_threshold", total_threshold), ("cde_threshold", cde_threshold)): + if value is not None and not 0 < value <= 100: + range_errors.append(f"`{name}` = {value} (must be greater than 0 and at most 100)") + if range_errors: + raise MCPUserError("Score threshold out of range: " + "; ".join(range_errors) + ".") + + +def _parse_monitor_trigger(value: str | None) -> MonitorNotificationTrigger: + if value is None: + return MonitorNotificationTrigger.on_anomalies + try: + label = MonitorTriggerLabel(value) + except ValueError as err: + valid = ", ".join(f"`{label.value}`" for label in MonitorTriggerLabel) + raise MCPUserError( + f"Invalid `trigger_on` `{value}` for event type `Monitor Alert`. Valid values: {valid}." + ) from err + return MONITOR_TRIGGER_LABEL_TO_INTERNAL[label] + + +@with_database_session +@mcp_permission("edit") +def update_notification( + notification_id: str, + *, + enabled: bool | None = None, + recipients: list[str] | None = None, + trigger_on: str | None = None, + total_threshold: float | None = None, + cde_threshold: float | None = None, + clear_total_threshold: bool = False, + clear_cde_threshold: bool = False, + table_name: str | None = None, + clear_table_name: bool = False, +) -> str: + """Update fields on an existing email notification. Pass only the fields to change. + + Works on any notification, including ``Monitor Alert`` notifications — those are + created through monitor setup rather than this tool, but can be updated here. + + Every invalid input surfaces in a single error before any save — no partial save. + The notification's event type and scope entity are immutable through this tool; + delete and recreate to change them. (A Monitor Alert's optional table — a finer + scope within its table group — can still be set or cleared here.) + + Args: + notification_id: UUID of the notification, e.g. from ``list_notifications``. + enabled: ``True`` to resume, ``False`` to pause. Omit to leave unchanged. + recipients: Replace the recipient list with the supplied addresses (one or more + well-formed emails). Omit to leave unchanged. + trigger_on: New trigger condition. Only valid for ``Test Run``, ``Profiling Run``, + and ``Monitor Alert`` notifications; rejected for ``Score Drop``. + For ``Test Run``: one of ``Always``, ``On test failures``, + ``On test failures and warnings``, ``On new test failures and warnings``. + For ``Profiling Run``: one of ``Always``, ``On new hygiene issues``. + For ``Monitor Alert``: ``On anomalies`` is the only supported value, so + this field cannot meaningfully be changed on Monitor Alert notifications. + total_threshold: New total score threshold (over 0, up to 100). Only valid for + ``Score Drop`` notifications. + cde_threshold: New critical-data-element score threshold (over 0, up to 100). Only valid + for ``Score Drop`` notifications. + clear_total_threshold: ``True`` to clear the overall-score threshold (set to + NULL). At least one threshold must remain set after the call. + clear_cde_threshold: ``True`` to clear the CDE-score threshold. At least one + threshold must remain set after the call. + table_name: Narrow a Monitor Alert notification's scope to a single table within + its table group. Only valid for ``Monitor Alert`` notifications. + clear_table_name: ``True`` to drop an existing table from a Monitor Alert + notification (notifications then fire for any table in the table group). + """ + if ( + enabled is None + and recipients is None + and trigger_on is None + and total_threshold is None + and cde_threshold is None + and not clear_total_threshold + and not clear_cde_threshold + and table_name is None + and not clear_table_name + ): + raise MCPUserError("No fields supplied to update.") + + notif = resolve_notification(notification_id) + event = notif.event + event_label = format_notification_event(event) + + _reject_event_stray_args( + event, + event_label, + trigger_on=trigger_on, + total_threshold=total_threshold, + cde_threshold=cde_threshold, + clear_total_threshold=clear_total_threshold, + clear_cde_threshold=clear_cde_threshold, + table_name=table_name, + clear_table_name=clear_table_name, + ) + + _reject_set_and_clear_conflicts( + total_threshold=total_threshold, + clear_total_threshold=clear_total_threshold, + cde_threshold=cde_threshold, + clear_cde_threshold=clear_cde_threshold, + table_name=table_name, + clear_table_name=clear_table_name, + ) + + clean_recipients: list[str] | None = None + if recipients is not None: + clean_recipients = _validate_recipients(recipients) + + parsed_trigger = None + if trigger_on is not None: + if event is NotificationEvent.test_run: + parsed_trigger = _parse_test_run_trigger(trigger_on) + elif event is NotificationEvent.profiling_run: + parsed_trigger = _parse_profiling_run_trigger(trigger_on) + elif event is NotificationEvent.monitor_run: + parsed_trigger = _parse_monitor_trigger(trigger_on) + + if event is NotificationEvent.score_drop: + _validate_threshold_range(total_threshold, cde_threshold) + _validate_score_drop_post_state( + notif, + total_threshold=total_threshold, + cde_threshold=cde_threshold, + clear_total_threshold=clear_total_threshold, + clear_cde_threshold=clear_cde_threshold, + ) + + pending = _build_pending( + notif, + enabled=enabled, + recipients=clean_recipients, + trigger=parsed_trigger, + total_threshold=total_threshold, + cde_threshold=cde_threshold, + clear_total_threshold=clear_total_threshold, + clear_cde_threshold=clear_cde_threshold, + table_name=table_name, + clear_table_name=clear_table_name, + ) + + doc = MdDoc() + doc.heading(1, f"{event_label} Notification updated") + doc.field("Notification ID", notif.id, code=True) + + if not pending: + doc.text("No fields changed — supplied values matched the current state.") + return doc.render() + + before = {attr: _snapshot_attr(notif, attr) for attr in pending} + for attr, value in pending.items(): + setattr(notif, attr, value) + after = {attr: _snapshot_attr(notif, attr) for attr in pending} + + notif.save() + + rows = [[_DIFF_LABELS[attr], before[attr], after[attr]] for attr in pending] + doc.table(["Field", "Before", "After"], rows) + return doc.render() + + +# --- update_notification helpers --- + + +_DIFF_LABELS: dict[str, str] = { + "enabled": "Status", + "recipients": "Recipients", + "trigger": "Trigger", + "total_score_threshold": "Total Score Threshold", + "cde_score_threshold": "CDE Score Threshold", + "table_name": "Table", +} + + +def _reject_event_stray_args( + event: NotificationEvent, + event_label: str, + *, + trigger_on: str | None, + total_threshold: float | None, + cde_threshold: float | None, + clear_total_threshold: bool, + clear_cde_threshold: bool, + table_name: str | None, + clear_table_name: bool, +) -> None: + """Reject args that are meaningless for the resolved event. + + Collects every stray arg into a single ``MCPUserError`` so the caller can fix + them all in one round-trip. The message names the relevant supported event + for each stray so the LLM knows where each arg actually applies. + """ + threshold_strays = [ + name + for name, supplied in ( + ("total_threshold", total_threshold is not None), + ("cde_threshold", cde_threshold is not None), + ("clear_total_threshold", clear_total_threshold), + ("clear_cde_threshold", clear_cde_threshold), + ) + if supplied + ] + table_strays = [ + name + for name, supplied in ( + ("table_name", table_name is not None), + ("clear_table_name", clear_table_name), + ) + if supplied + ] + + messages: list[str] = [] + if event is NotificationEvent.score_drop: + if trigger_on is not None: + messages.append( + f"`trigger_on` is not supported for event type `{event_label}` — thresholds drive the event." + ) + if table_strays: + joined = ", ".join(f"`{name}`" for name in table_strays) + messages.append( + f"{joined} not supported for event type `{event_label}`. " + "Only `Monitor Alert` notifications can be scoped to a table." + ) + else: + if threshold_strays: + joined = ", ".join(f"`{name}`" for name in threshold_strays) + messages.append( + f"{joined} not supported for event type `{event_label}`. " + "Only `Score Drop` notifications use score thresholds." + ) + if event is not NotificationEvent.monitor_run and table_strays: + joined = ", ".join(f"`{name}`" for name in table_strays) + messages.append( + f"{joined} not supported for event type `{event_label}`. " + "Only `Monitor Alert` notifications can be scoped to a table." + ) + + if messages: + raise MCPUserError(" ".join(messages)) + + +def _reject_set_and_clear_conflicts( + *, + total_threshold: float | None, + clear_total_threshold: bool, + cde_threshold: float | None, + clear_cde_threshold: bool, + table_name: str | None, + clear_table_name: bool, +) -> None: + """Reject any (set, clear) pair where the caller supplied both for the same field.""" + conflicts = [ + name + for name, set_supplied, clear_supplied in ( + ("total_threshold", total_threshold is not None, clear_total_threshold), + ("cde_threshold", cde_threshold is not None, clear_cde_threshold), + ("table_name", table_name is not None, clear_table_name), + ) + if set_supplied and clear_supplied + ] + if conflicts: + joined = ", ".join(f"`{name}`" for name in conflicts) + raise MCPUserError(f"{joined} cannot be both set and cleared in the same call.") + + +def _validate_score_drop_post_state( + notif: NotificationSettings, + *, + total_threshold: float | None, + cde_threshold: float | None, + clear_total_threshold: bool, + clear_cde_threshold: bool, +) -> None: + """Pre-empt model.save()'s "at least one threshold" invariant. + + Compute the effective threshold values that would result from applying the + pending change and reject up-front if both would be NULL. + """ + if clear_total_threshold: + effective_total = None + elif total_threshold is not None: + effective_total = total_threshold + else: + effective_total = notif.total_score_threshold + + if clear_cde_threshold: + effective_cde = None + elif cde_threshold is not None: + effective_cde = cde_threshold + else: + effective_cde = notif.cde_score_threshold + + if effective_total is None and effective_cde is None: + raise MCPUserError( + "At least one of `total_threshold` or `cde_threshold` must remain set " + "for a `Score Drop` notification." + ) + + +def _build_pending( + notif: NotificationSettings, + *, + enabled: bool | None, + recipients: list[str] | None, + trigger: object, + total_threshold: float | None, + cde_threshold: float | None, + clear_total_threshold: bool, + clear_cde_threshold: bool, + table_name: str | None, + clear_table_name: bool, +) -> dict[str, object]: + """Return only the changes that actually differ from the current state.""" + pending: dict[str, object] = {} + + if enabled is not None and notif.enabled != enabled: + pending["enabled"] = enabled + + if recipients is not None and list(notif.recipients or []) != recipients: + pending["recipients"] = recipients + + if trigger is not None and notif.trigger != trigger: + pending["trigger"] = trigger + + if clear_total_threshold and notif.total_score_threshold is not None: + pending["total_score_threshold"] = None + elif total_threshold is not None and notif.total_score_threshold != total_threshold: + pending["total_score_threshold"] = total_threshold + + if clear_cde_threshold and notif.cde_score_threshold is not None: + pending["cde_score_threshold"] = None + elif cde_threshold is not None and notif.cde_score_threshold != cde_threshold: + pending["cde_score_threshold"] = cde_threshold + + if clear_table_name and notif.table_name is not None: + pending["table_name"] = None + elif table_name is not None and notif.table_name != table_name: + pending["table_name"] = table_name + + return pending + + +def _snapshot_attr(notif: NotificationSettings, attr: str) -> object: + """Render a single attribute's current value in display form for the diff table.""" + if attr == "enabled": + return "Active" if notif.enabled else "Paused" + if attr == "recipients": + return ", ".join(notif.recipients or []) or None + if attr == "trigger": + return _label_for_trigger(notif.event, notif.trigger) + if attr == "total_score_threshold": + return _format_threshold(notif.total_score_threshold) + if attr == "cde_score_threshold": + return _format_threshold(notif.cde_score_threshold) + if attr == "table_name": + return notif.table_name or None + return None + + +def _label_for_trigger(event: NotificationEvent, trigger: object) -> str | None: + """Render the user-facing label for an in-memory trigger enum value.""" + if trigger is None: + return None + if event is NotificationEvent.test_run and isinstance(trigger, TestRunNotificationTrigger): + return format_notification_trigger(event, {"trigger": trigger.value}) + if event is NotificationEvent.profiling_run and isinstance(trigger, ProfilingRunNotificationTrigger): + return format_notification_trigger(event, {"trigger": trigger.value}) + if event is NotificationEvent.monitor_run and isinstance(trigger, MonitorNotificationTrigger): + return format_notification_trigger(event, {"trigger": trigger.value}) + return None + + +def _format_threshold(value: object) -> str | None: + """Render a stored Decimal threshold (or an in-memory float/int) as a display string.""" + if value is None: + return None + if isinstance(value, Decimal): + return str(value) + return str(value) + + +@with_database_session +@mcp_permission("edit") +def delete_notification(notification_id: str) -> str: + """Delete an email notification. + + Works on any notification, including ``Monitor Alert`` notifications — those are + created through monitor setup rather than this tool, but can be deleted here. + + Args: + notification_id: UUID of the notification, e.g. from ``list_notifications``. + """ + notif = resolve_notification(notification_id) + event_label = format_notification_event(notif.event) + + doc = MdDoc() + doc.heading(1, f"{event_label} Notification deleted") + doc.field("Notification ID", notif.id, code=True) + doc.field("Event Type", event_label) + doc.field("Project", notif.project_code, code=True) + _render_scope_fields(doc, notif) + + notif.delete() + + return doc.render() + + +def _render_one(notif: NotificationSettings) -> str: + doc = MdDoc() + event_label = format_notification_event(notif.event) + doc.heading(1, f"{event_label} Notification") + _render_notification_body(doc, notif) + return doc.render() + + +def _render_created(notif: NotificationSettings) -> str: + doc = MdDoc() + event_label = format_notification_event(notif.event) + doc.heading(1, f"{event_label} Notification created") + _render_notification_body(doc, notif) + return doc.render() + + +def _render_notification_body(doc: MdDoc, notif: NotificationSettings) -> None: + event_label = format_notification_event(notif.event) + status_word = "Active" if notif.enabled else "Paused" + + doc.heading(2, "Configuration") + doc.field("Notification ID", notif.id, code=True) + doc.field("Event Type", event_label) + doc.field("Status", status_word) + if trigger_label := format_notification_trigger(notif.event, notif.settings): + doc.field("Trigger", trigger_label) + if notif.event == NotificationEvent.score_drop: + total_threshold = (notif.settings or {}).get("total_threshold") + cde_threshold = (notif.settings or {}).get("cde_threshold") + if total_threshold is not None: + doc.field("Total Score Threshold", total_threshold) + if cde_threshold is not None: + doc.field("CDE Score Threshold", cde_threshold) + + doc.heading(2, "Scope") + doc.field("Project", notif.project_code, code=True) + _render_scope_fields(doc, notif) + + doc.heading(2, "Recipients") + if notif.recipients: + doc.bullets(list(notif.recipients)) + else: + doc.text("_No recipients configured._") + + +class _ScopeEntityKind(StrEnum): + SUITE = "suite" + TABLE_GROUP = "table_group" + SCORECARD = "scorecard" + + +@dataclass(frozen=True) +class _ScopeField: + label: str + id_attr: str + all_label: str + kind: _ScopeEntityKind + + +_SUITE_FIELD = _ScopeField("Test Suite", "test_suite_id", "All Test Suites", _ScopeEntityKind.SUITE) +_TABLE_GROUP_FIELD = _ScopeField("Table Group", "table_group_id", "All Table Groups", _ScopeEntityKind.TABLE_GROUP) +_SCORECARD_FIELD = _ScopeField("Scorecard", "score_definition_id", "All Scorecards", _ScopeEntityKind.SCORECARD) + +# Single source of truth: which scope entities (and labels) each event renders. +# Both the detail view (_render_scope_fields) and the list view (_scope_text) iterate this. +# Monitors are scoped to their table group only — the underlying monitor test suite is an +# internal detail that is never surfaced. An optional table narrows the scope further (see +# the monitor ``table_name`` handling in both renderers). +_SCOPE_FIELDS: dict[NotificationEvent, tuple[_ScopeField, ...]] = { + NotificationEvent.test_run: (_SUITE_FIELD,), + NotificationEvent.profiling_run: (_TABLE_GROUP_FIELD,), + NotificationEvent.score_drop: (_SCORECARD_FIELD,), + NotificationEvent.monitor_run: (_TABLE_GROUP_FIELD,), +} + + +def _render_scope_fields(doc: MdDoc, notif: NotificationSettings) -> None: + for field in _SCOPE_FIELDS.get(notif.event, ()): + entity_id = getattr(notif, field.id_attr) + name = _resolve_scope_name(field.kind, entity_id) + doc.field(field.label, _scope_value(name, entity_id, field.all_label)) + if notif.event == NotificationEvent.monitor_run and (table_name := (notif.settings or {}).get("table_name")): + doc.field("Table", table_name) + + +def _resolve_scope_name(kind: _ScopeEntityKind, entity_id: UUID | None) -> str | None: + if kind is _ScopeEntityKind.SUITE: + return _suite_name(entity_id) + if kind is _ScopeEntityKind.TABLE_GROUP: + return _table_group_name(entity_id) + return _scorecard_name(entity_id) + + +def _scope_value(name: str | None, entity_id: UUID | None, project_wide_label: str) -> str: + if entity_id is None: + return project_wide_label + display = name or str(entity_id) + return f"{display} ({MdDoc.code(str(entity_id))})" + + +def _suite_name(suite_id: UUID | None) -> str | None: + if suite_id is None: + return None + suite = TestSuite.get(suite_id) + return suite.test_suite if suite else None + + +def _table_group_name(tg_id: UUID | None) -> str | None: + if tg_id is None: + return None + tg = TableGroup.get(tg_id) + return tg.table_groups_name if tg else None + + +def _scorecard_name(score_id: UUID | None) -> str | None: + if score_id is None: + return None + sd = ScoreDefinition.get(str(score_id)) + return sd.name if sd else None + + +def _render( + rows: list[NotificationSummary], + total: int, + *, + page: int, + limit: int, + scope_label: str | None, +) -> str: + doc = MdDoc() + heading = f"Email Notifications — {scope_label}" if scope_label else "Email Notifications" + doc.heading(1, heading) + + if not rows: + doc.text("_No notifications match the supplied scope._") + return doc.render() + + if info := format_page_info(total, page, limit): + doc.text(info) + + suite_names = _batch_suite_names({r.test_suite_id for r in rows if r.test_suite_id}) + tg_names = _batch_table_group_names({r.table_group_id for r in rows if r.table_group_id}) + score_names = _batch_score_names({r.score_definition_id for r in rows if r.score_definition_id}) + + for r in rows: + status_word = "Active" if r.enabled else "Paused" + event_label = format_notification_event(r.event) + scope_text = _scope_text(r, suite_names, tg_names, score_names) + doc.heading(2, f"[{status_word}] {event_label} Notification — {scope_text}") + doc.field("Notification ID", r.id, code=True) + doc.field("Event Type", event_label) + doc.field("Status", status_word) + doc.field("Project", r.project_code, code=True) + doc.field("Scope", scope_text) + if trigger_label := format_notification_trigger(r.event, r.settings): + doc.field("Trigger", trigger_label) + if r.event == NotificationEvent.score_drop: + total_threshold = (r.settings or {}).get("total_threshold") + cde_threshold = (r.settings or {}).get("cde_threshold") + if total_threshold is not None: + doc.field("Total Score Threshold", total_threshold) + if cde_threshold is not None: + doc.field("CDE Score Threshold", cde_threshold) + doc.field("Recipients", ", ".join(r.recipients or []) or None) + + if footer := format_page_footer(total, page, limit): + doc.text(footer) + + return doc.render() + + +def _scope_text( + row: NotificationSummary, + suite_names: dict[UUID, str], + tg_names: dict[UUID, str], + score_names: dict[UUID, str], +) -> str: + batches = { + _ScopeEntityKind.SUITE: suite_names, + _ScopeEntityKind.TABLE_GROUP: tg_names, + _ScopeEntityKind.SCORECARD: score_names, + } + fields = _SCOPE_FIELDS.get(row.event, ()) + if not fields: + return "—" + # A project-wide entity reads as a bare label, e.g. "All Table Groups". + parts = [] + for field in fields: + entity_id = getattr(row, field.id_attr) + if entity_id is None: + parts.append(field.all_label) + else: + parts.append(f"{field.label}: {batches[field.kind].get(entity_id, str(entity_id))}") + if row.event == NotificationEvent.monitor_run and (table_name := (row.settings or {}).get("table_name")): + parts.append(f"Table: {table_name}") + return " · ".join(parts) + + +def _batch_suite_names(suite_ids: set[UUID]) -> dict[UUID, str]: + if not suite_ids: + return {} + return {s.id: s.test_suite for s in TestSuite.select_minimal_where(TestSuite.id.in_(list(suite_ids)))} + + +def _batch_table_group_names(tg_ids: set[UUID]) -> dict[UUID, str]: + if not tg_ids: + return {} + return {tg.id: tg.table_groups_name for tg in TableGroup.select_minimal_where(TableGroup.id.in_(list(tg_ids)))} + + +def _batch_score_names(score_ids: set[UUID]) -> dict[UUID, str]: + if not score_ids: + return {} + return ScoreDefinition.names_by_id(score_ids) diff --git a/tests/unit/common/models/test_notification_settings.py b/tests/unit/common/models/test_notification_settings.py new file mode 100644 index 00000000..1a46817d --- /dev/null +++ b/tests/unit/common/models/test_notification_settings.py @@ -0,0 +1,139 @@ +"""Tests for ``NotificationSettings`` query semantics. + +The listing surface (``list_for_test_suite`` / ``list_for_table_group`` / +``list_for_score_definition``) must use strict equality on the scope column — +no ``IS NULL`` wildcard. The firing-pipeline surface (``_base_select_query``) +must keep the ``IS NULL`` wildcard so a project-wide notification matches +events on any child entity. +""" + +from unittest.mock import patch +from uuid import UUID, uuid4 + +import pytest + +from testgen.common.models.notification_settings import NotificationSettings, is_valid_email + +pytestmark = pytest.mark.unit + + +# ─── Shared email validation helper ─────────────────────────────────── + + +@pytest.mark.parametrize("addr", [ + "alice@example.com", + "a.b+tag@sub.domain.co", + "x_y%z@host-name.io", +]) +def test_is_valid_email_accepts_well_formed(addr): + assert is_valid_email(addr) is True + + +@pytest.mark.parametrize("addr", [ + "no-at-sign", + "spaces in@here.com", + "nodot@nope", + "@nodomain.com", + "trailing@dot.", + "", +]) +def test_is_valid_email_rejects_malformed(addr): + assert is_valid_email(addr) is False + + +def _captured_list_sql(method_name: str, *args, **kwargs) -> str: + """Invoke a ``list_for_*`` classmethod and compile the query it passes to ``_paginate``.""" + with patch.object(NotificationSettings, "_paginate", return_value=([], 0)) as mock_paginate: + getattr(NotificationSettings, method_name)(*args, **kwargs) + query = mock_paginate.call_args.args[0] + return str(query.compile(compile_kwargs={"literal_binds": True})) + + +def _uuid_in_sql(value: UUID, sql: str) -> bool: + """SQLAlchemy literal_binds compiles UUIDs as 32-char hex (no dashes); accept either.""" + return str(value) in sql or value.hex in sql + + +# ─── Listing surface — strict equality, no IS NULL ──────────────────── + + +def test_list_for_test_suite_filters_by_strict_equality_only(): + suite_id = uuid4() + sql = _captured_list_sql("list_for_test_suite", suite_id) + + assert "IS NULL" not in sql.upper(), ( + "list_for_test_suite must not surface rows where test_suite_id IS NULL — " + "they may be unrelated event types whose scope column happens to be null." + ) + assert "test_suite_id" in sql + assert _uuid_in_sql(suite_id, sql) + + +def test_list_for_table_group_filters_by_strict_equality_only(): + table_group_id = uuid4() + sql = _captured_list_sql("list_for_table_group", table_group_id) + + assert "IS NULL" not in sql.upper(), ( + "list_for_table_group must not surface rows where table_group_id IS NULL — " + "they may be unrelated event types whose scope column happens to be null." + ) + assert "table_group_id" in sql + assert _uuid_in_sql(table_group_id, sql) + + +def test_list_for_score_definition_filters_by_strict_equality_only(): + score_definition_id = uuid4() + sql = _captured_list_sql("list_for_score_definition", score_definition_id) + + assert "IS NULL" not in sql.upper(), ( + "list_for_score_definition must not surface rows where score_definition_id IS NULL — " + "they may be unrelated event types whose scope column happens to be null." + ) + assert "score_definition_id" in sql + assert _uuid_in_sql(score_definition_id, sql) + + +# ─── Firing pipeline — IS NULL preserved (regression guard) ─────────── +# +# `_base_select_query` is consumed by the notification firing pipeline, where +# a notification with `_id IS NULL` legitimately means "fires for any +# child of that type in the same project." Leaving this branch alone is the +# whole reason the listing-side fix is scoped to the `list_for_*` helpers. + + +def _firing_query_sql(**kwargs) -> str: + query = NotificationSettings._base_select_query(**kwargs) + return str(query.compile(compile_kwargs={"literal_binds": True})) + + +def test_base_select_query_test_suite_keeps_null_wildcard(): + suite_id = uuid4() + sql = _firing_query_sql(test_suite_id=suite_id) + + assert "IS NULL" in sql.upper(), ( + "_base_select_query is used by the firing pipeline, which needs " + "test_suite_id IS NULL to mean 'fires for any suite in the project'." + ) + assert _uuid_in_sql(suite_id, sql) + + +def test_base_select_query_table_group_keeps_null_wildcard(): + table_group_id = uuid4() + sql = _firing_query_sql(table_group_id=table_group_id) + + assert "IS NULL" in sql.upper(), ( + "_base_select_query is used by the firing pipeline, which needs " + "table_group_id IS NULL to mean 'fires for any table group in the project'." + ) + assert _uuid_in_sql(table_group_id, sql) + + +def test_base_select_query_score_definition_keeps_null_wildcard(): + score_definition_id = uuid4() + sql = _firing_query_sql(score_definition_id=score_definition_id) + + assert "IS NULL" in sql.upper(), ( + "_base_select_query is used by the firing pipeline, which needs " + "score_definition_id IS NULL to mean 'fires for any scorecard in the project'." + ) + assert _uuid_in_sql(score_definition_id, sql) diff --git a/tests/unit/common/models/test_score_definition.py b/tests/unit/common/models/test_score_definition.py index 2be46506..7978448b 100644 --- a/tests/unit/common/models/test_score_definition.py +++ b/tests/unit/common/models/test_score_definition.py @@ -432,3 +432,42 @@ def test_list_for_project_count_null_returns_zero(mock_session_fn): items, total = ScoreDefinition.list_for_project("demo") assert items == [] assert total == 0 + + +# ─── names_by_id — single batched lookup, no N+1 ────────────────────── + + +@patch("testgen.common.models.scores.get_current_session") +def test_names_by_id_returns_id_to_name_mapping(mock_session_fn): + id_a, id_b = uuid4(), uuid4() + mock_result = MagicMock() + mock_result.all.return_value = [_row(id_a, "Card A", None), _row(id_b, "Card B", None)] + mock_session_fn.return_value.execute.return_value = mock_result + + out = ScoreDefinition.names_by_id([id_a, id_b]) + + assert out == {id_a: "Card A", id_b: "Card B"} + + +@patch("testgen.common.models.scores.get_current_session") +def test_names_by_id_empty_input_skips_query(mock_session_fn): + out = ScoreDefinition.names_by_id([]) + + assert out == {} + mock_session_fn.return_value.execute.assert_not_called() + + +@patch("testgen.common.models.scores.get_current_session") +def test_names_by_id_uses_single_in_query(mock_session_fn): + """One IN query for all IDs — not a per-ID lookup (N+1).""" + ids = [uuid4(), uuid4(), uuid4()] + mock_result = MagicMock() + mock_result.all.return_value = [] + mock_session_fn.return_value.execute.return_value = mock_result + + ScoreDefinition.names_by_id(ids) + + assert mock_session_fn.return_value.execute.call_count == 1 + args, _ = mock_session_fn.return_value.execute.call_args + sql = str(args[0].compile(compile_kwargs={"literal_binds": True})) + assert " IN (" in sql.upper() diff --git a/tests/unit/mcp/test_tools_notifications.py b/tests/unit/mcp/test_tools_notifications.py new file mode 100644 index 00000000..9ce2eccd --- /dev/null +++ b/tests/unit/mcp/test_tools_notifications.py @@ -0,0 +1,2318 @@ +from decimal import Decimal +from unittest.mock import MagicMock, patch +from uuid import UUID, uuid4 + +import pytest + +from testgen.common.models.notification_settings import ( + MonitorNotificationTrigger, + NotificationEvent, + NotificationSummary, + ProfilingRunNotificationTrigger, + TestRunNotificationTrigger, +) +from testgen.mcp.exceptions import MCPResourceNotAccessible, MCPUserError +from testgen.mcp.permissions import ProjectPermissions +from testgen.mcp.tools.common import ( + MONITOR_TRIGGER_LABEL_TO_INTERNAL, + NOTIFICATION_EVENT_LABEL_TO_INTERNAL, + PROFILING_RUN_TRIGGER_LABEL_TO_INTERNAL, + TEST_RUN_TRIGGER_LABEL_TO_INTERNAL, + format_notification_event, + format_notification_trigger, +) + +pytestmark = pytest.mark.unit + + +# --- Helpers --- + + +def _patch_perms(allowed=("demo",), memberships=None): + memberships = memberships or dict.fromkeys(allowed, "role_a") + return patch( + "testgen.mcp.permissions._compute_project_permissions", + return_value=ProjectPermissions( + memberships=memberships, permission="view", username="test_user", + ), + ) + + +def _summary( + *, + event: NotificationEvent, + enabled: bool = True, + project_code: str = "demo", + recipients=("alice@example.com",), + test_suite_id: UUID | None = None, + table_group_id: UUID | None = None, + score_definition_id: UUID | None = None, + settings: dict | None = None, +) -> NotificationSummary: + return NotificationSummary( + id=uuid4(), + project_code=project_code, + event=event, + enabled=enabled, + recipients=list(recipients), + test_suite_id=test_suite_id, + table_group_id=table_group_id, + score_definition_id=score_definition_id, + settings=settings or {}, + ) + + +def _patch_list_for_projects(rows, total): + return patch( + "testgen.common.models.notification_settings.NotificationSettings.list_for_projects", + return_value=(rows, total), + ) + + +def _patch_list_for_test_suite(rows, total): + return patch( + "testgen.common.models.notification_settings.NotificationSettings.list_for_test_suite", + return_value=(rows, total), + ) + + +def _patch_list_for_table_group(rows, total): + return patch( + "testgen.common.models.notification_settings.NotificationSettings.list_for_table_group", + return_value=(rows, total), + ) + + +def _patch_list_for_score_definition(rows, total): + return patch( + "testgen.common.models.notification_settings.NotificationSettings.list_for_score_definition", + return_value=(rows, total), + ) + + +def _patch_no_resolve_lookups(): + """Make the batch-name helpers return empty dicts so tests don't need TestSuite/TableGroup mocks + unless they care about scope-name rendering. + """ + return patch.multiple( + "testgen.mcp.tools.notifications", + _batch_suite_names=MagicMock(return_value={}), + _batch_table_group_names=MagicMock(return_value={}), + _batch_score_names=MagicMock(return_value={}), + ) + + +# --- format helpers --- + + +def test_format_notification_event_round_trip(): + """Every NotificationEvent has a stable display label.""" + seen_labels = set() + for event in NotificationEvent: + label = format_notification_event(event) + seen_labels.add(label) + # Round-trip the label back to the internal enum. + assert NOTIFICATION_EVENT_LABEL_TO_INTERNAL[ + type(next(iter(NOTIFICATION_EVENT_LABEL_TO_INTERNAL)))(label) + ] is event + assert seen_labels == {"Test Run", "Profiling Run", "Score Drop", "Monitor Alert"} + + +def test_format_notification_event_accepts_raw_string(): + assert format_notification_event("test_run") == "Test Run" + + +def test_format_notification_trigger_test_run_labels(): + for trigger, label_enum in { + TestRunNotificationTrigger.always: "Always", + TestRunNotificationTrigger.on_failures: "On test failures", + TestRunNotificationTrigger.on_warnings: "On test failures and warnings", + TestRunNotificationTrigger.on_changes: "On new test failures and warnings", + }.items(): + assert ( + format_notification_trigger(NotificationEvent.test_run, {"trigger": trigger.value}) + == label_enum + ) + + +def test_format_notification_trigger_profiling_labels(): + assert ( + format_notification_trigger(NotificationEvent.profiling_run, {"trigger": "always"}) + == "Always" + ) + assert ( + format_notification_trigger(NotificationEvent.profiling_run, {"trigger": "on_changes"}) + == "On new hygiene issues" + ) + + +def test_format_notification_trigger_monitor_label(): + assert ( + format_notification_trigger(NotificationEvent.monitor_run, {"trigger": "on_anomalies"}) + == "On anomalies" + ) + + +def test_format_notification_trigger_score_drop_returns_none(): + assert format_notification_trigger(NotificationEvent.score_drop, {"total_threshold": "95.0"}) is None + + +def test_format_notification_trigger_missing_settings_returns_none(): + assert format_notification_trigger(NotificationEvent.test_run, None) is None + assert format_notification_trigger(NotificationEvent.test_run, {}) is None + + +def test_trigger_label_to_internal_dicts_cover_every_internal_enum(): + """No internal enum value should be missing a display label — both directions are total.""" + assert set(TEST_RUN_TRIGGER_LABEL_TO_INTERNAL.values()) == set(TestRunNotificationTrigger) + assert set(PROFILING_RUN_TRIGGER_LABEL_TO_INTERNAL.values()) == set(ProfilingRunNotificationTrigger) + assert set(MONITOR_TRIGGER_LABEL_TO_INTERNAL.values()) == set(MonitorNotificationTrigger) + + +def test_scope_fields_cover_every_event(): + """Every notification event must have a scope-field descriptor — no event can be + added without declaring which scope entities (and labels) it renders. + """ + from testgen.mcp.tools.notifications import _SCOPE_FIELDS + + assert set(_SCOPE_FIELDS) == set(NotificationEvent) + + +# --- Argument validation --- + + +def test_list_notifications_rejects_two_scope_args(db_session_mock): + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), pytest.raises(MCPUserError, match="at most one"): + list_notifications(project_code="demo", test_suite_id=str(uuid4())) + + +def test_list_notifications_rejects_three_scope_args(db_session_mock): + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), pytest.raises(MCPUserError, match="at most one"): + list_notifications(test_suite_id=str(uuid4()), table_group_id=str(uuid4()), scorecard_id=str(uuid4())) + + +@pytest.mark.parametrize("page,limit", [(0, 10), (1, 0), (1, 201)]) +def test_list_notifications_rejects_invalid_pagination(db_session_mock, page, limit): + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), pytest.raises(MCPUserError): + list_notifications(page=page, limit=limit) + + +def test_list_notifications_invalid_test_suite_uuid(db_session_mock): + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), pytest.raises(MCPUserError, match="not a valid UUID"): + list_notifications(test_suite_id="not-a-uuid") + + +def test_list_notifications_invalid_table_group_uuid(db_session_mock): + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), pytest.raises(MCPUserError, match="not a valid UUID"): + list_notifications(table_group_id="not-a-uuid") + + +def test_list_notifications_invalid_scorecard_uuid(db_session_mock): + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), pytest.raises(MCPUserError, match="not a valid UUID"): + list_notifications(scorecard_id="not-a-uuid") + + +def test_list_notifications_rejects_inaccessible_project(db_session_mock): + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(allowed=("demo",)), pytest.raises( + MCPResourceNotAccessible, match=r"Project.*forbidden_proj" + ): + list_notifications(project_code="forbidden_proj") + + +@patch("testgen.mcp.tools.common.TestSuite.get") +def test_list_notifications_rejects_inaccessible_test_suite(mock_suite_get, db_session_mock): + mock_suite_get.return_value = None + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), pytest.raises(MCPResourceNotAccessible, match="Test suite"): + list_notifications(test_suite_id=str(uuid4())) + + +@patch("testgen.mcp.tools.common.TableGroup.get") +def test_list_notifications_rejects_inaccessible_table_group(mock_tg_get, db_session_mock): + mock_tg_get.return_value = None + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), pytest.raises(MCPResourceNotAccessible, match="Table group"): + list_notifications(table_group_id=str(uuid4())) + + +@patch("testgen.mcp.tools.common.ScoreDefinition.get") +def test_list_notifications_rejects_inaccessible_scorecard(mock_score_get, db_session_mock): + mock_score_get.return_value = None + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), pytest.raises(MCPResourceNotAccessible, match="Scorecard"): + list_notifications(scorecard_id=str(uuid4())) + + +# --- Listing & dispatch --- + + +def test_list_notifications_no_scope_uses_allowed_projects(db_session_mock): + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(allowed=("demo", "other")), _patch_list_for_projects([], 0) as mock_list: + list_notifications() + + args, kwargs = mock_list.call_args + assert sorted(args[0]) == ["demo", "other"] + assert kwargs["page"] == 1 + assert kwargs["limit"] == 50 + + +def test_list_notifications_project_scope_dispatches_to_list_for_projects(db_session_mock): + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), _patch_list_for_projects([], 0) as mock_list: + list_notifications(project_code="demo") + + args, _ = mock_list.call_args + assert args[0] == ["demo"] + + +@patch("testgen.mcp.tools.common.TestSuite.get") +def test_list_notifications_test_suite_scope_dispatches_to_list_for_test_suite( + mock_suite_get, db_session_mock, +): + suite_uuid = uuid4() + suite_mock = MagicMock() + suite_mock.id = suite_uuid + suite_mock.test_suite = "orders_v1" + suite_mock.project_code = "demo" + mock_suite_get.return_value = suite_mock + + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), _patch_list_for_test_suite([], 0) as mock_list: + list_notifications(test_suite_id=str(suite_uuid)) + + args, _ = mock_list.call_args + assert args[0] == suite_uuid + + +@patch("testgen.mcp.tools.common.TableGroup.get") +def test_list_notifications_table_group_scope_dispatches_to_list_for_table_group( + mock_tg_get, db_session_mock, +): + tg_uuid = uuid4() + tg_mock = MagicMock() + tg_mock.id = tg_uuid + tg_mock.table_groups_name = "prod_warehouse" + tg_mock.project_code = "demo" + mock_tg_get.return_value = tg_mock + + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), _patch_list_for_table_group([], 0) as mock_list: + list_notifications(table_group_id=str(tg_uuid)) + + args, _ = mock_list.call_args + assert args[0] == tg_uuid + + +@patch("testgen.mcp.tools.common.ScoreDefinition.get") +def test_list_notifications_scorecard_scope_dispatches_to_list_for_score_definition( + mock_score_get, db_session_mock, +): + sd_uuid = uuid4() + sd_mock = MagicMock() + sd_mock.id = sd_uuid + sd_mock.name = "Daily Orders Health" + sd_mock.project_code = "demo" + mock_score_get.return_value = sd_mock + + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), _patch_list_for_score_definition([], 0) as mock_list: + list_notifications(scorecard_id=str(sd_uuid)) + + args, _ = mock_list.call_args + assert args[0] == sd_uuid + + +# --- Rendering --- + + +def test_list_notifications_empty_renders_friendly_message(db_session_mock): + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), _patch_list_for_projects([], 0): + out = list_notifications() + + assert "# Email Notifications" in out + assert "_No notifications match the supplied scope._" in out + + +def test_list_notifications_renders_test_run_with_suite_scope(db_session_mock): + suite_id = uuid4() + row = _summary( + event=NotificationEvent.test_run, + test_suite_id=suite_id, + settings={"trigger": "on_failures"}, + recipients=("alice@example.com", "bob@example.com"), + ) + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), _patch_list_for_projects([row], 1), patch( + "testgen.mcp.tools.notifications._batch_suite_names", + return_value={suite_id: "orders_v1"}, + ), patch( + "testgen.mcp.tools.notifications._batch_table_group_names", return_value={}, + ), patch( + "testgen.mcp.tools.notifications._batch_score_names", return_value={}, + ): + out = list_notifications() + + assert "[Active] Test Run Notification" in out + assert "Test Suite: orders_v1" in out + assert "On test failures" in out + assert "alice@example.com, bob@example.com" in out + # No internal code leakage + assert "test_run" not in out + assert "on_failures" not in out + + +def test_list_notifications_renders_profiling_run_project_wide(db_session_mock): + row = _summary( + event=NotificationEvent.profiling_run, + enabled=False, + table_group_id=None, + settings={"trigger": "on_changes"}, + recipients=("ops@example.com",), + ) + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), _patch_list_for_projects([row], 1), _patch_no_resolve_lookups(): + out = list_notifications() + + assert "[Paused] Profiling Run Notification" in out + assert "All Table Groups" in out + assert "(project-wide)" not in out + assert "On new hygiene issues" in out + assert "Status:** Paused" in out + + +def test_list_notifications_renders_score_drop_thresholds(db_session_mock): + sd_id = uuid4() + row = _summary( + event=NotificationEvent.score_drop, + score_definition_id=sd_id, + settings={"total_threshold": "95.0", "cde_threshold": "90.0"}, + recipients=("alerts@example.com",), + ) + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), _patch_list_for_projects([row], 1), patch( + "testgen.mcp.tools.notifications._batch_score_names", + return_value={sd_id: "Daily Orders Health"}, + ), patch( + "testgen.mcp.tools.notifications._batch_suite_names", return_value={}, + ), patch( + "testgen.mcp.tools.notifications._batch_table_group_names", return_value={}, + ): + out = list_notifications() + + assert "Score Drop Notification" in out + assert "Scorecard: Daily Orders Health" in out + assert "Total Score Threshold:** 95.0" in out + assert "CDE Score Threshold:** 90.0" in out + # Score Drop has no trigger label + assert "Trigger:**" not in out + + +def test_list_notifications_renders_score_drop_one_threshold_only(db_session_mock): + sd_id = uuid4() + row = _summary( + event=NotificationEvent.score_drop, + score_definition_id=sd_id, + settings={"total_threshold": "95.0", "cde_threshold": None}, + ) + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), _patch_list_for_projects([row], 1), patch( + "testgen.mcp.tools.notifications._batch_score_names", + return_value={sd_id: "Card"}, + ), patch( + "testgen.mcp.tools.notifications._batch_suite_names", return_value={}, + ), patch( + "testgen.mcp.tools.notifications._batch_table_group_names", return_value={}, + ): + out = list_notifications() + + assert "Total Score Threshold:** 95.0" in out + assert "CDE Score Threshold" not in out + + +def test_list_notifications_renders_monitor_run_scope(db_session_mock): + tg_id = uuid4() + suite_id = uuid4() + row = _summary( + event=NotificationEvent.monitor_run, + table_group_id=tg_id, + test_suite_id=suite_id, + settings={"trigger": "on_anomalies", "table_name": "orders"}, + recipients=("monitor-alerts@example.com",), + ) + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), _patch_list_for_projects([row], 1), patch( + "testgen.mcp.tools.notifications._batch_suite_names", + return_value={suite_id: "monitors_v2"}, + ), patch( + "testgen.mcp.tools.notifications._batch_table_group_names", + return_value={tg_id: "prod_warehouse"}, + ), patch( + "testgen.mcp.tools.notifications._batch_score_names", return_value={}, + ): + out = list_notifications() + + assert "Monitor Alert Notification" in out + assert "Table Group: prod_warehouse" in out + assert "Table: orders" in out + assert "On anomalies" in out + # The monitor's internal test suite is never exposed — monitors are scoped to the table group. + assert "Test Suite" not in out + assert "monitors_v2" not in out + + +def test_list_notifications_pagination_renders_info_and_footer(db_session_mock): + rows = [ + _summary(event=NotificationEvent.test_run, settings={"trigger": "always"}) for _ in range(3) + ] + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(), _patch_list_for_projects(rows, 25), _patch_no_resolve_lookups(): + out = list_notifications(page=1, limit=3) + + # format_page_info emits an en-dash (\u2013) between start and end. + assert "Showing 1\u20133 of 25" in out + assert "Use `page=2` for more" in out + + +def test_list_notifications_passes_allowed_codes_only(db_session_mock): + """Even with no scope arg, the dispatch only sees the caller's allowed projects.""" + from testgen.mcp.tools.notifications import list_notifications + + with _patch_perms(allowed=("alpha", "beta")), _patch_list_for_projects([], 0) as mock_list: + list_notifications() + args, _ = mock_list.call_args + assert "alpha" in args[0] + assert "beta" in args[0] + assert "gamma" not in args[0] + + +# --- get_notification --- + + +def _notif_mock( + *, + event: NotificationEvent, + enabled: bool = True, + project_code: str = "demo", + recipients=("alice@example.com",), + test_suite_id: UUID | None = None, + table_group_id: UUID | None = None, + score_definition_id: UUID | None = None, + settings: dict | None = None, +) -> MagicMock: + """Build a mock that quacks like a polymorphic ``NotificationSettings`` ORM row.""" + notif = MagicMock() + notif.id = uuid4() + notif.event = event + notif.enabled = enabled + notif.project_code = project_code + notif.recipients = list(recipients) + notif.test_suite_id = test_suite_id + notif.table_group_id = table_group_id + notif.score_definition_id = score_definition_id + notif.settings = settings or {} + return notif + + +def _patch_notification_get(return_value): + return patch( + "testgen.mcp.tools.common.NotificationSettings.get", + return_value=return_value, + ) + + +def _patch_get_notification_scope_lookups( + *, suite_name: str | None = None, tg_name: str | None = None, score_name: str | None = None, +): + """Patch the per-entity scope-name lookups used by ``_render_one``. + + Each patched ``.get`` returns a MagicMock with the supplied name attribute (or ``None``). + Tests that don't care about scope names pass nothing. + """ + suite_mock = None + if suite_name is not None: + suite_mock = MagicMock() + suite_mock.test_suite = suite_name + tg_mock = None + if tg_name is not None: + tg_mock = MagicMock() + tg_mock.table_groups_name = tg_name + score_mock = None + if score_name is not None: + score_mock = MagicMock() + score_mock.name = score_name + + return patch.multiple( + "testgen.mcp.tools.notifications", + TestSuite=MagicMock(get=MagicMock(return_value=suite_mock)), + TableGroup=MagicMock(get=MagicMock(return_value=tg_mock)), + ScoreDefinition=MagicMock(get=MagicMock(return_value=score_mock)), + ) + + +def test_get_notification_invalid_uuid(db_session_mock): + from testgen.mcp.tools.notifications import get_notification + + with _patch_perms(), pytest.raises(MCPUserError, match="not a valid UUID"): + get_notification(notification_id="not-a-uuid") + + +def test_get_notification_missing_returns_unified_not_accessible(db_session_mock): + from testgen.mcp.tools.notifications import get_notification + + with _patch_perms(), _patch_notification_get(None), pytest.raises( + MCPResourceNotAccessible, match="Notification", + ): + get_notification(notification_id=str(uuid4())) + + +def test_get_notification_inaccessible_project_returns_unified_not_accessible(db_session_mock): + """``NotificationSettings.get`` returns ``None`` when the project filter excludes the row. + + Both the missing-id and the wrong-project paths must surface as the same error + so callers can't enumerate notifications across projects they don't own. + """ + from testgen.mcp.tools.notifications import get_notification + + with _patch_perms(allowed=("demo",)), _patch_notification_get(None), pytest.raises( + MCPResourceNotAccessible, match="Notification", + ): + get_notification(notification_id=str(uuid4())) + + +def test_get_notification_test_run_with_suite_renders_all_sections(db_session_mock): + suite_id = uuid4() + notif = _notif_mock( + event=NotificationEvent.test_run, + test_suite_id=suite_id, + settings={"trigger": "on_failures"}, + recipients=("alice@example.com", "bob@example.com"), + ) + from testgen.mcp.tools.notifications import get_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups( + suite_name="orders_v1", + ): + out = get_notification(notification_id=str(notif.id)) + + # H1 + section headings + assert "# Test Run Notification" in out + assert "## Configuration" in out + assert "## Scope" in out + assert "## Recipients" in out + # Configuration fields + assert "Event Type:** Test Run" in out + assert "Status:** Active" in out + assert "Trigger:** On test failures" in out + # Scope surfaces suite name + id for chaining + assert "Project:** `demo`" in out + assert "Test Suite:** orders_v1" in out + assert f"`{suite_id}`" in out + # Recipients as bullets + assert "- alice@example.com" in out + assert "- bob@example.com" in out + # No internal code leakage + assert "test_run" not in out + assert "on_failures" not in out + + +def test_get_notification_test_run_project_wide_omits_suite_id(db_session_mock): + notif = _notif_mock( + event=NotificationEvent.test_run, + test_suite_id=None, + settings={"trigger": "always"}, + ) + from testgen.mcp.tools.notifications import get_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = get_notification(notification_id=str(notif.id)) + + assert "Test Suite:** All Test Suites" in out + # Project-wide notifications have no parent id to surface. + assert "(`" not in out.split("## Scope")[1] + + +def test_get_notification_profiling_run_with_table_group(db_session_mock): + tg_id = uuid4() + notif = _notif_mock( + event=NotificationEvent.profiling_run, + table_group_id=tg_id, + settings={"trigger": "on_changes"}, + ) + from testgen.mcp.tools.notifications import get_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups( + tg_name="prod_warehouse", + ): + out = get_notification(notification_id=str(notif.id)) + + assert "# Profiling Run Notification" in out + assert "Trigger:** On new hygiene issues" in out + assert "Table Group:** prod_warehouse" in out + assert f"`{tg_id}`" in out + + +def test_get_notification_score_drop_renders_thresholds_and_omits_trigger(db_session_mock): + sd_id = uuid4() + notif = _notif_mock( + event=NotificationEvent.score_drop, + score_definition_id=sd_id, + settings={"total_threshold": "85.0", "cde_threshold": "90.0"}, + ) + from testgen.mcp.tools.notifications import get_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups( + score_name="Daily Orders Health", + ): + out = get_notification(notification_id=str(notif.id)) + + assert "# Score Drop Notification" in out + assert "Total Score Threshold:** 85.0" in out + assert "CDE Score Threshold:** 90.0" in out + assert "Trigger:**" not in out + assert "Scorecard:** Daily Orders Health" in out + + +def test_get_notification_score_drop_only_total_threshold(db_session_mock): + notif = _notif_mock( + event=NotificationEvent.score_drop, + score_definition_id=uuid4(), + settings={"total_threshold": "85.0", "cde_threshold": None}, + ) + from testgen.mcp.tools.notifications import get_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups( + score_name="Card", + ): + out = get_notification(notification_id=str(notif.id)) + + assert "Total Score Threshold:** 85.0" in out + assert "CDE Score Threshold" not in out + + +def test_get_notification_monitor_run_renders_table_group_and_table(db_session_mock): + tg_id = uuid4() + suite_id = uuid4() + notif = _notif_mock( + event=NotificationEvent.monitor_run, + table_group_id=tg_id, + test_suite_id=suite_id, + settings={"trigger": "on_anomalies", "table_name": "orders"}, + ) + from testgen.mcp.tools.notifications import get_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups( + suite_name="monitors_v2", tg_name="prod_warehouse", + ): + out = get_notification(notification_id=str(notif.id)) + + assert "# Monitor Alert Notification" in out + assert "Trigger:** On anomalies" in out + # The table is part of the monitor's scope, rendered as "Table" (not a "Filtered Table" filter). + assert "Table:** orders" in out + assert "Table Group:** prod_warehouse" in out + assert f"`{tg_id}`" in out + # The internal monitor test suite is never exposed. + assert "Test Suite" not in out + assert "monitors_v2" not in out + assert f"`{suite_id}`" not in out + + +def test_get_notification_paused_renders_status_paused(db_session_mock): + notif = _notif_mock( + event=NotificationEvent.test_run, + enabled=False, + test_suite_id=uuid4(), + settings={"trigger": "always"}, + ) + from testgen.mcp.tools.notifications import get_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups( + suite_name="some_suite", + ): + out = get_notification(notification_id=str(notif.id)) + + assert "Status:** Paused" in out + + +# --------------------------------------------------------------------------- +# create_notification +# --------------------------------------------------------------------------- + + +def _make_create_suite(name="orders_v1", project_code="demo"): + suite = MagicMock() + suite.id = uuid4() + suite.test_suite = name + suite.project_code = project_code + suite.is_monitor = False + return suite + + +def _make_create_table_group(name="prod_warehouse", project_code="demo"): + tg = MagicMock() + tg.id = uuid4() + tg.table_groups_name = name + tg.project_code = project_code + return tg + + +def _make_create_scorecard(name="Daily Orders Health", project_code="demo"): + sd = MagicMock() + sd.id = uuid4() + sd.name = name + sd.project_code = project_code + return sd + + +def _make_saved_notif( + *, + event: NotificationEvent, + project_code: str = "demo", + recipients=("alice@example.com",), + test_suite_id: UUID | None = None, + table_group_id: UUID | None = None, + score_definition_id: UUID | None = None, + settings: dict | None = None, + enabled: bool = True, +) -> MagicMock: + """Mock that quacks like the polymorphic ``NotificationSettings`` row returned by ``.create()``.""" + notif = MagicMock() + notif.id = uuid4() + notif.event = event + notif.enabled = enabled + notif.project_code = project_code + notif.recipients = list(recipients) + notif.test_suite_id = test_suite_id + notif.table_group_id = table_group_id + notif.score_definition_id = score_definition_id + notif.settings = settings or {} + return notif + + +# --- Happy paths --- + + +@patch("testgen.mcp.tools.notifications.TestRunNotificationSettings") +@patch("testgen.mcp.tools.notifications.resolve_test_suite") +def test_create_notification_test_run_happy_path(mock_resolve_suite, mock_factory, db_session_mock): + suite = _make_create_suite(name="orders_v1") + mock_resolve_suite.return_value = suite + saved = _make_saved_notif( + event=NotificationEvent.test_run, + test_suite_id=suite.id, + settings={"trigger": "on_failures"}, + recipients=("alice@example.com", "bob@example.com"), + ) + mock_factory.create.return_value = saved + + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), _patch_get_notification_scope_lookups(suite_name="orders_v1"): + out = create_notification( + event_type="Test Run", + recipients=["alice@example.com", "bob@example.com"], + test_suite_id=str(suite.id), + trigger_on="On test failures", + ) + + mock_factory.create.assert_called_once_with( + project_code="demo", + test_suite_id=suite.id, + recipients=["alice@example.com", "bob@example.com"], + trigger=TestRunNotificationTrigger.on_failures, + ) + # Confirmation heading + assert "created" in out.lower() + # Display labels, not internal codes + assert "Test Run" in out + assert "On test failures" in out + assert "test_run" not in out + assert "on_failures" not in out + # Followable-IDs surface + assert f"`{saved.id}`" in out + # Recipients rendered + assert "alice@example.com" in out + assert "bob@example.com" in out + # Scope name surfaced + assert "orders_v1" in out + + +@patch("testgen.mcp.tools.notifications.ProfilingRunNotificationSettings") +@patch("testgen.mcp.tools.notifications.resolve_table_group") +def test_create_notification_profiling_run_happy_path(mock_resolve_tg, mock_factory, db_session_mock): + tg = _make_create_table_group(name="prod_warehouse") + mock_resolve_tg.return_value = tg + saved = _make_saved_notif( + event=NotificationEvent.profiling_run, + table_group_id=tg.id, + settings={"trigger": "on_changes"}, + ) + mock_factory.create.return_value = saved + + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), _patch_get_notification_scope_lookups(tg_name="prod_warehouse"): + out = create_notification( + event_type="Profiling Run", + recipients=["ops@example.com"], + table_group_id=str(tg.id), + trigger_on="On new hygiene issues", + ) + + mock_factory.create.assert_called_once_with( + project_code="demo", + table_group_id=tg.id, + recipients=["ops@example.com"], + trigger=ProfilingRunNotificationTrigger.on_changes, + ) + assert "Profiling Run" in out + assert "On new hygiene issues" in out + assert "prod_warehouse" in out + assert f"`{saved.id}`" in out + # No internal code leakage + assert "profiling_run" not in out + assert "on_changes" not in out + + +@patch("testgen.mcp.tools.notifications.ScoreDropNotificationSettings") +@patch("testgen.mcp.tools.notifications.resolve_scorecard") +def test_create_notification_score_drop_happy_path_both_thresholds( + mock_resolve_sc, + mock_factory, + db_session_mock, +): + scorecard = _make_create_scorecard(name="Daily Orders Health") + mock_resolve_sc.return_value = scorecard + saved = _make_saved_notif( + event=NotificationEvent.score_drop, + score_definition_id=scorecard.id, + settings={"total_threshold": "85.0", "cde_threshold": "90.0"}, + recipients=("alerts@example.com",), + ) + mock_factory.create.return_value = saved + + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), _patch_get_notification_scope_lookups(score_name="Daily Orders Health"): + out = create_notification( + event_type="Score Drop", + recipients=["alerts@example.com"], + scorecard_id=str(scorecard.id), + total_threshold=85, + cde_threshold=90, + ) + + mock_factory.create.assert_called_once_with( + project_code="demo", + score_definition_id=scorecard.id, + recipients=["alerts@example.com"], + total_score_threshold=85, + cde_score_threshold=90, + ) + assert "Score Drop" in out + assert "Daily Orders Health" in out + assert "85" in out + assert "90" in out + assert f"`{saved.id}`" in out + # Score Drop has no trigger label + assert "Trigger:**" not in out + + +@patch("testgen.mcp.tools.notifications.ScoreDropNotificationSettings") +@patch("testgen.mcp.tools.notifications.resolve_scorecard") +def test_create_notification_score_drop_happy_path_total_only( + mock_resolve_sc, + mock_factory, + db_session_mock, +): + scorecard = _make_create_scorecard() + mock_resolve_sc.return_value = scorecard + saved = _make_saved_notif( + event=NotificationEvent.score_drop, + score_definition_id=scorecard.id, + settings={"total_threshold": "85.0", "cde_threshold": None}, + ) + mock_factory.create.return_value = saved + + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), _patch_get_notification_scope_lookups(score_name="card"): + create_notification( + event_type="Score Drop", + recipients=["x@example.com"], + scorecard_id=str(scorecard.id), + total_threshold=85, + ) + + mock_factory.create.assert_called_once_with( + project_code="demo", + score_definition_id=scorecard.id, + recipients=["x@example.com"], + total_score_threshold=85, + cde_score_threshold=None, + ) + + +@patch("testgen.mcp.tools.notifications.ScoreDropNotificationSettings") +@patch("testgen.mcp.tools.notifications.resolve_scorecard") +def test_create_notification_score_drop_happy_path_cde_only( + mock_resolve_sc, + mock_factory, + db_session_mock, +): + scorecard = _make_create_scorecard() + mock_resolve_sc.return_value = scorecard + saved = _make_saved_notif( + event=NotificationEvent.score_drop, + score_definition_id=scorecard.id, + settings={"total_threshold": None, "cde_threshold": "90.0"}, + ) + mock_factory.create.return_value = saved + + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), _patch_get_notification_scope_lookups(score_name="card"): + create_notification( + event_type="Score Drop", + recipients=["x@example.com"], + scorecard_id=str(scorecard.id), + cde_threshold=90, + ) + + mock_factory.create.assert_called_once_with( + project_code="demo", + score_definition_id=scorecard.id, + recipients=["x@example.com"], + total_score_threshold=None, + cde_score_threshold=90, + ) + + +# --- Defaults --- + + +@patch("testgen.mcp.tools.notifications.TestRunNotificationSettings") +@patch("testgen.mcp.tools.notifications.resolve_test_suite") +def test_create_notification_test_run_default_trigger_on( + mock_resolve_suite, + mock_factory, + db_session_mock, +): + """Omitting ``trigger_on`` for Test Run defaults to ``On test failures``.""" + suite = _make_create_suite() + mock_resolve_suite.return_value = suite + saved = _make_saved_notif( + event=NotificationEvent.test_run, + test_suite_id=suite.id, + settings={"trigger": "on_failures"}, + ) + mock_factory.create.return_value = saved + + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), _patch_get_notification_scope_lookups(suite_name="x"): + create_notification( + event_type="Test Run", + recipients=["x@example.com"], + test_suite_id=str(suite.id), + ) + + _, kwargs = mock_factory.create.call_args + assert kwargs["trigger"] == TestRunNotificationTrigger.on_failures + + +@patch("testgen.mcp.tools.notifications.ProfilingRunNotificationSettings") +@patch("testgen.mcp.tools.notifications.resolve_table_group") +def test_create_notification_profiling_run_default_trigger_on( + mock_resolve_tg, + mock_factory, + db_session_mock, +): + """Omitting ``trigger_on`` for Profiling Run defaults to ``On new hygiene issues``.""" + tg = _make_create_table_group() + mock_resolve_tg.return_value = tg + saved = _make_saved_notif( + event=NotificationEvent.profiling_run, + table_group_id=tg.id, + settings={"trigger": "on_changes"}, + ) + mock_factory.create.return_value = saved + + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), _patch_get_notification_scope_lookups(tg_name="x"): + create_notification( + event_type="Profiling Run", + recipients=["x@example.com"], + table_group_id=str(tg.id), + ) + + _, kwargs = mock_factory.create.call_args + assert kwargs["trigger"] == ProfilingRunNotificationTrigger.on_changes + + +# --- Errors: event_type --- + + +def test_create_notification_internal_event_code_rejected(db_session_mock): + """Internal enum codes (``test_run``) are NOT accepted — display labels only.""" + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError) as exc: + create_notification( + event_type="test_run", + recipients=["x@example.com"], + test_suite_id=str(uuid4()), + ) + msg = str(exc.value) + for label in ("Test Run", "Profiling Run", "Score Drop"): + assert label in msg + + +def test_create_notification_unknown_event_type_rejected(db_session_mock): + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError, match="event_type"): + create_notification( + event_type="Bogus", + recipients=["x@example.com"], + test_suite_id=str(uuid4()), + ) + + +def test_create_notification_monitor_run_not_creatable(db_session_mock): + """Monitor Alert is out of scope for create — only test/profiling/score events.""" + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError) as exc: + create_notification( + event_type="Monitor Alert", + recipients=["x@example.com"], + test_suite_id=str(uuid4()), + table_group_id=str(uuid4()), + ) + msg = str(exc.value) + # Error lists the supported labels + for label in ("Test Run", "Profiling Run", "Score Drop"): + assert label in msg + + +# --- Errors: scope arg shape --- + + +def test_create_notification_test_run_missing_test_suite_id_rejected(db_session_mock): + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError, match="test_suite_id"): + create_notification(event_type="Test Run", recipients=["x@example.com"]) + + +def test_create_notification_profiling_run_missing_table_group_id_rejected(db_session_mock): + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError, match="table_group_id"): + create_notification(event_type="Profiling Run", recipients=["x@example.com"]) + + +def test_create_notification_score_drop_missing_scorecard_id_rejected(db_session_mock): + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError, match="scorecard_id"): + create_notification( + event_type="Score Drop", + recipients=["x@example.com"], + total_threshold=85, + ) + + +def test_create_notification_test_run_with_table_group_id_rejected(db_session_mock): + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError) as exc: + create_notification( + event_type="Test Run", + recipients=["x@example.com"], + test_suite_id=str(uuid4()), + table_group_id=str(uuid4()), + ) + assert "table_group_id" in str(exc.value) + + +def test_create_notification_test_run_with_scorecard_id_rejected(db_session_mock): + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError, match="scorecard_id"): + create_notification( + event_type="Test Run", + recipients=["x@example.com"], + test_suite_id=str(uuid4()), + scorecard_id=str(uuid4()), + ) + + +def test_create_notification_profiling_run_with_test_suite_id_rejected(db_session_mock): + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError, match="test_suite_id"): + create_notification( + event_type="Profiling Run", + recipients=["x@example.com"], + table_group_id=str(uuid4()), + test_suite_id=str(uuid4()), + ) + + +def test_create_notification_score_drop_with_test_suite_id_rejected(db_session_mock): + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError, match="test_suite_id"): + create_notification( + event_type="Score Drop", + recipients=["x@example.com"], + scorecard_id=str(uuid4()), + test_suite_id=str(uuid4()), + total_threshold=85, + ) + + +# --- Errors: inaccessible scope entities --- + + +@patch("testgen.mcp.tools.notifications.resolve_test_suite") +def test_create_notification_inaccessible_test_suite_propagates( + mock_resolve_suite, + db_session_mock, +): + mock_resolve_suite.side_effect = MCPResourceNotAccessible("Test suite", "x") + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPResourceNotAccessible, match="Test suite"): + create_notification( + event_type="Test Run", + recipients=["x@example.com"], + test_suite_id=str(uuid4()), + ) + + +@patch("testgen.mcp.tools.notifications.resolve_table_group") +def test_create_notification_inaccessible_table_group_propagates( + mock_resolve_tg, + db_session_mock, +): + mock_resolve_tg.side_effect = MCPResourceNotAccessible("Table group", "x") + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPResourceNotAccessible, match="Table group"): + create_notification( + event_type="Profiling Run", + recipients=["x@example.com"], + table_group_id=str(uuid4()), + ) + + +@patch("testgen.mcp.tools.notifications.resolve_scorecard") +def test_create_notification_inaccessible_scorecard_propagates( + mock_resolve_sc, + db_session_mock, +): + mock_resolve_sc.side_effect = MCPResourceNotAccessible("Scorecard", "x") + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPResourceNotAccessible, match="Scorecard"): + create_notification( + event_type="Score Drop", + recipients=["x@example.com"], + scorecard_id=str(uuid4()), + total_threshold=85, + ) + + +# --- Errors: recipients --- + + +@patch("testgen.mcp.tools.notifications.resolve_test_suite") +def test_create_notification_empty_recipients_rejected(mock_resolve_suite, db_session_mock): + mock_resolve_suite.return_value = _make_create_suite() + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError, match="at least one"): + create_notification( + event_type="Test Run", + recipients=[], + test_suite_id=str(uuid4()), + ) + + +@patch("testgen.mcp.tools.notifications.resolve_test_suite") +def test_create_notification_invalid_recipients_lists_all( + mock_resolve_suite, + db_session_mock, +): + """Every malformed address appears in the single error message — no partial save.""" + mock_resolve_suite.return_value = _make_create_suite() + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError) as exc: + create_notification( + event_type="Test Run", + recipients=[ + "alice@example.com", + "no-at-sign", + "spaces in@here.com", + "nodot@nope", + ], + test_suite_id=str(uuid4()), + ) + msg = str(exc.value) + assert "no-at-sign" in msg + assert "spaces in@here.com" in msg + assert "nodot@nope" in msg + + +# --- Errors: trigger_on --- + + +@patch("testgen.mcp.tools.notifications.resolve_test_suite") +def test_create_notification_invalid_trigger_on_test_run_lists_all_labels( + mock_resolve_suite, + db_session_mock, +): + mock_resolve_suite.return_value = _make_create_suite() + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError) as exc: + create_notification( + event_type="Test Run", + recipients=["x@example.com"], + test_suite_id=str(uuid4()), + trigger_on="bogus", + ) + msg = str(exc.value) + for label in ( + "Always", + "On test failures", + "On test failures and warnings", + "On new test failures and warnings", + ): + assert label in msg + + +@patch("testgen.mcp.tools.notifications.resolve_table_group") +def test_create_notification_invalid_trigger_on_profiling_run_lists_only_profiling_labels( + mock_resolve_tg, + db_session_mock, +): + mock_resolve_tg.return_value = _make_create_table_group() + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError) as exc: + create_notification( + event_type="Profiling Run", + recipients=["x@example.com"], + table_group_id=str(uuid4()), + trigger_on="bogus", + ) + msg = str(exc.value) + assert "Always" in msg + assert "On new hygiene issues" in msg + # Test-run-only triggers must NOT leak into the Profiling Run error + assert "On test failures" not in msg + + +# --- Errors: score_drop thresholds --- + + +@patch("testgen.mcp.tools.notifications.resolve_scorecard") +def test_create_notification_score_drop_missing_both_thresholds_rejected( + mock_resolve_sc, + db_session_mock, +): + mock_resolve_sc.return_value = _make_create_scorecard() + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError, match="threshold"): + create_notification( + event_type="Score Drop", + recipients=["x@example.com"], + scorecard_id=str(uuid4()), + ) + + +@patch("testgen.mcp.tools.notifications.resolve_scorecard") +def test_create_notification_score_drop_thresholds_out_of_range_lists_all( + mock_resolve_sc, + db_session_mock, +): + """Both threshold range issues are surfaced in one error — no partial save.""" + mock_resolve_sc.return_value = _make_create_scorecard() + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError) as exc: + create_notification( + event_type="Score Drop", + recipients=["x@example.com"], + scorecard_id=str(uuid4()), + total_threshold=150, + cde_threshold=-1, + ) + msg = str(exc.value) + assert "total_threshold" in msg + assert "cde_threshold" in msg + assert "150" in msg + assert "-1" in msg + + +@patch("testgen.mcp.tools.notifications.resolve_scorecard") +def test_create_notification_score_drop_zero_total_threshold_rejected(mock_resolve_sc, db_session_mock): + """0 is not a valid threshold (a score can never drop below 0) — reject up front + with a clear MCPUserError, not the opaque model error. + """ + mock_resolve_sc.return_value = _make_create_scorecard() + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError) as exc: + create_notification( + event_type="Score Drop", + recipients=["x@example.com"], + scorecard_id=str(uuid4()), + total_threshold=0, + ) + msg = str(exc.value) + assert "total_threshold" in msg + assert "= 0" in msg + + +@patch("testgen.mcp.tools.notifications.resolve_scorecard") +def test_create_notification_score_drop_zero_cde_threshold_rejected(mock_resolve_sc, db_session_mock): + mock_resolve_sc.return_value = _make_create_scorecard() + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError) as exc: + create_notification( + event_type="Score Drop", + recipients=["x@example.com"], + scorecard_id=str(uuid4()), + cde_threshold=0, + ) + msg = str(exc.value) + assert "cde_threshold" in msg + assert "= 0" in msg + + +# --- Errors: stray args per event --- + + +@patch("testgen.mcp.tools.notifications.resolve_scorecard") +def test_create_notification_score_drop_with_trigger_on_rejected( + mock_resolve_sc, + db_session_mock, +): + mock_resolve_sc.return_value = _make_create_scorecard() + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError, match="trigger_on"): + create_notification( + event_type="Score Drop", + recipients=["x@example.com"], + scorecard_id=str(uuid4()), + total_threshold=85, + trigger_on="Always", + ) + + +@patch("testgen.mcp.tools.notifications.resolve_test_suite") +def test_create_notification_test_run_with_thresholds_rejected( + mock_resolve_suite, + db_session_mock, +): + mock_resolve_suite.return_value = _make_create_suite() + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError) as exc: + create_notification( + event_type="Test Run", + recipients=["x@example.com"], + test_suite_id=str(uuid4()), + total_threshold=85, + cde_threshold=90, + ) + msg = str(exc.value) + assert "total_threshold" in msg + assert "cde_threshold" in msg + + +@patch("testgen.mcp.tools.notifications.resolve_table_group") +def test_create_notification_profiling_run_with_thresholds_rejected( + mock_resolve_tg, + db_session_mock, +): + mock_resolve_tg.return_value = _make_create_table_group() + from testgen.mcp.tools.notifications import create_notification + + with _patch_perms(), pytest.raises(MCPUserError) as exc: + create_notification( + event_type="Profiling Run", + recipients=["x@example.com"], + table_group_id=str(uuid4()), + total_threshold=85, + ) + assert "total_threshold" in str(exc.value) + + +# --------------------------------------------------------------------------- +# update_notification +# --------------------------------------------------------------------------- + + +def _update_mock( + *, + event: NotificationEvent, + enabled: bool = True, + project_code: str = "demo", + recipients=("alice@example.com",), + test_suite_id: UUID | None = None, + table_group_id: UUID | None = None, + score_definition_id: UUID | None = None, + trigger=None, + total_score_threshold=None, + cde_score_threshold=None, + table_name: str | None = None, +) -> MagicMock: + """Build a polymorphic-notification mock for ``update_notification`` tests. + + Adds typed attributes (``trigger``, ``total_score_threshold``, + ``cde_score_threshold``, ``table_name``) that the tool reads when computing + the no-op / Before-After diff. Each defaults to ``None`` unless supplied. + """ + notif = _notif_mock( + event=event, + enabled=enabled, + project_code=project_code, + recipients=recipients, + test_suite_id=test_suite_id, + table_group_id=table_group_id, + score_definition_id=score_definition_id, + ) + notif.trigger = trigger + notif.total_score_threshold = total_score_threshold + notif.cde_score_threshold = cde_score_threshold + notif.table_name = table_name + return notif + + +# --- Pre-mutation validation --- + + +def test_update_notification_invalid_uuid(db_session_mock): + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), pytest.raises(MCPUserError, match="not a valid UUID"): + update_notification(notification_id="not-a-uuid", enabled=False) + + +def test_update_notification_missing_returns_unified_not_accessible(db_session_mock): + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(None), pytest.raises( + MCPResourceNotAccessible, match="Notification", + ): + update_notification(notification_id=str(uuid4()), enabled=False) + + +def test_update_notification_no_fields_returns_error(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises( + MCPUserError, match="No fields supplied to update", + ): + update_notification(notification_id=str(notif.id)) + + +# --- Event-shape gates --- + + +def test_update_notification_test_run_rejects_total_threshold(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification(notification_id=str(notif.id), total_threshold=85) + assert "total_threshold" in str(exc.value) + assert "Test Run" in str(exc.value) + + +def test_update_notification_test_run_rejects_clear_cde_threshold(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification(notification_id=str(notif.id), clear_cde_threshold=True) + assert "clear_cde_threshold" in str(exc.value) + + +def test_update_notification_test_run_rejects_table_name(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification(notification_id=str(notif.id), table_name="orders") + assert "table_name" in str(exc.value) + assert "Monitor Alert" in str(exc.value) + + +def test_update_notification_profiling_run_rejects_cde_threshold(db_session_mock): + notif = _update_mock(event=NotificationEvent.profiling_run, table_group_id=uuid4(), + trigger=ProfilingRunNotificationTrigger.on_changes) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification(notification_id=str(notif.id), cde_threshold=85) + assert "cde_threshold" in str(exc.value) + + +def test_update_notification_profiling_run_rejects_table_name(db_session_mock): + notif = _update_mock(event=NotificationEvent.profiling_run, table_group_id=uuid4(), + trigger=ProfilingRunNotificationTrigger.on_changes) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification(notification_id=str(notif.id), clear_table_name=True) + assert "table_name" in str(exc.value) + + +def test_update_notification_score_drop_rejects_trigger_on(db_session_mock): + notif = _update_mock(event=NotificationEvent.score_drop, score_definition_id=uuid4(), + total_score_threshold=Decimal("85.0")) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError, match="trigger_on"): + update_notification(notification_id=str(notif.id), trigger_on="Always") + + +def test_update_notification_score_drop_rejects_table_name(db_session_mock): + notif = _update_mock(event=NotificationEvent.score_drop, score_definition_id=uuid4(), + total_score_threshold=Decimal("85.0")) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError, match="table_name"): + update_notification(notification_id=str(notif.id), table_name="orders") + + +def test_update_notification_monitor_run_rejects_total_threshold(db_session_mock): + notif = _update_mock(event=NotificationEvent.monitor_run, + table_group_id=uuid4(), test_suite_id=uuid4(), + trigger=MonitorNotificationTrigger.on_anomalies) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification(notification_id=str(notif.id), total_threshold=85) + assert "total_threshold" in str(exc.value) + + +def test_update_notification_multiple_stray_args_one_error(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification( + notification_id=str(notif.id), + total_threshold=85, + cde_threshold=90, + table_name="orders", + ) + msg = str(exc.value) + assert "total_threshold" in msg + assert "cde_threshold" in msg + assert "table_name" in msg + + +# --- Recipients --- + + +def test_update_notification_empty_recipients_rejected(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError, match="at least one"): + update_notification(notification_id=str(notif.id), recipients=[]) + + +def test_update_notification_invalid_recipients_lists_all(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification( + notification_id=str(notif.id), + recipients=["alice@example.com", "no-at-sign", "nodot@nope"], + ) + msg = str(exc.value) + assert "no-at-sign" in msg + assert "nodot@nope" in msg + + +# --- Trigger labels --- + + +def test_update_notification_test_run_invalid_trigger_lists_all_labels(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification(notification_id=str(notif.id), trigger_on="bogus") + msg = str(exc.value) + for label in ( + "Always", + "On test failures", + "On test failures and warnings", + "On new test failures and warnings", + ): + assert label in msg + + +def test_update_notification_profiling_run_invalid_trigger_lists_only_profiling_labels(db_session_mock): + notif = _update_mock(event=NotificationEvent.profiling_run, table_group_id=uuid4(), + trigger=ProfilingRunNotificationTrigger.on_changes) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification(notification_id=str(notif.id), trigger_on="bogus") + msg = str(exc.value) + assert "Always" in msg + assert "On new hygiene issues" in msg + assert "On test failures" not in msg + + +def test_update_notification_monitor_run_invalid_trigger_lists_monitor_label(db_session_mock): + notif = _update_mock(event=NotificationEvent.monitor_run, + table_group_id=uuid4(), test_suite_id=uuid4(), + trigger=MonitorNotificationTrigger.on_anomalies) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification(notification_id=str(notif.id), trigger_on="bogus") + msg = str(exc.value) + assert "On anomalies" in msg + # Test-run-only triggers must not leak into the Monitor Alert error + assert "On test failures" not in msg + + +# --- Score thresholds --- + + +def test_update_notification_total_threshold_out_of_range(db_session_mock): + notif = _update_mock(event=NotificationEvent.score_drop, score_definition_id=uuid4(), + total_score_threshold=Decimal("85.0")) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification(notification_id=str(notif.id), total_threshold=150) + msg = str(exc.value) + assert "total_threshold" in msg + assert "150" in msg + + +def test_update_notification_zero_threshold_rejected(db_session_mock): + """0 is rejected on update with a clear error, not silently accepted or surfaced as opaque.""" + notif = _update_mock(event=NotificationEvent.score_drop, score_definition_id=uuid4(), + total_score_threshold=Decimal("85.0")) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification(notification_id=str(notif.id), total_threshold=0) + msg = str(exc.value) + assert "total_threshold" in msg + assert "= 0" in msg + notif.save.assert_not_called() + + +def test_update_notification_both_thresholds_out_of_range_one_error(db_session_mock): + notif = _update_mock(event=NotificationEvent.score_drop, score_definition_id=uuid4(), + total_score_threshold=Decimal("85.0"), + cde_score_threshold=Decimal("90.0")) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification(notification_id=str(notif.id), total_threshold=150, cde_threshold=-1) + msg = str(exc.value) + assert "total_threshold" in msg + assert "cde_threshold" in msg + assert "150" in msg + assert "-1" in msg + + +def test_update_notification_set_total_and_clear_total_rejected(db_session_mock): + notif = _update_mock(event=NotificationEvent.score_drop, score_definition_id=uuid4(), + total_score_threshold=Decimal("85.0")) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification( + notification_id=str(notif.id), + total_threshold=80, + clear_total_threshold=True, + ) + msg = str(exc.value) + assert "total_threshold" in msg + assert "set and cleared" in msg + + +def test_update_notification_set_and_clear_both_pairs_one_error(db_session_mock): + notif = _update_mock(event=NotificationEvent.score_drop, score_definition_id=uuid4(), + total_score_threshold=Decimal("85.0"), + cde_score_threshold=Decimal("90.0")) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification( + notification_id=str(notif.id), + total_threshold=80, + clear_total_threshold=True, + cde_threshold=70, + clear_cde_threshold=True, + ) + msg = str(exc.value) + assert "total_threshold" in msg + assert "cde_threshold" in msg + + +def test_update_notification_clear_both_thresholds_pre_empt_check(db_session_mock): + notif = _update_mock(event=NotificationEvent.score_drop, score_definition_id=uuid4(), + total_score_threshold=Decimal("85.0"), + cde_score_threshold=Decimal("90.0")) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError, match="must remain set"): + update_notification( + notification_id=str(notif.id), + clear_total_threshold=True, + clear_cde_threshold=True, + ) + notif.save.assert_not_called() + + +def test_update_notification_clear_only_set_threshold_pre_empt_check(db_session_mock): + """Current state: total=85, cde=NULL. Clearing total would leave both NULL.""" + notif = _update_mock(event=NotificationEvent.score_drop, score_definition_id=uuid4(), + total_score_threshold=Decimal("85.0"), + cde_score_threshold=None) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError, match="must remain set"): + update_notification(notification_id=str(notif.id), clear_total_threshold=True) + notif.save.assert_not_called() + + +# --- Monitor table_name --- + + +def test_update_notification_set_and_clear_table_name_rejected(db_session_mock): + notif = _update_mock(event=NotificationEvent.monitor_run, + table_group_id=uuid4(), test_suite_id=uuid4(), + trigger=MonitorNotificationTrigger.on_anomalies, + table_name="orders") + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), pytest.raises(MCPUserError) as exc: + update_notification( + notification_id=str(notif.id), + table_name="invoices", + clear_table_name=True, + ) + msg = str(exc.value) + assert "table_name" in msg + assert "set and cleared" in msg + + +def test_update_notification_monitor_set_table_name_happy(db_session_mock): + notif = _update_mock(event=NotificationEvent.monitor_run, + table_group_id=uuid4(), test_suite_id=uuid4(), + trigger=MonitorNotificationTrigger.on_anomalies, + table_name="orders") + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification(notification_id=str(notif.id), table_name="invoices") + + assert notif.table_name == "invoices" + notif.save.assert_called_once() + assert "orders" in out + assert "invoices" in out + assert "| Table |" in out + + +def test_update_notification_monitor_clear_table_name_happy(db_session_mock): + notif = _update_mock(event=NotificationEvent.monitor_run, + table_group_id=uuid4(), test_suite_id=uuid4(), + trigger=MonitorNotificationTrigger.on_anomalies, + table_name="orders") + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification(notification_id=str(notif.id), clear_table_name=True) + + assert notif.table_name is None + notif.save.assert_called_once() + assert "orders" in out + # Cleared values render as em-dash. + assert "—" in out + + +# --- No-op detection --- + + +def test_update_notification_no_op_enabled_returns_unchanged(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + enabled=True, trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification(notification_id=str(notif.id), enabled=True) + + assert "No fields changed" in out + notif.save.assert_not_called() + + +def test_update_notification_no_op_recipients(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + recipients=("a@x.com", "b@x.com"), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification( + notification_id=str(notif.id), + recipients=["a@x.com", "b@x.com"], + ) + + assert "No fields changed" in out + notif.save.assert_not_called() + + +def test_update_notification_no_op_trigger(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification(notification_id=str(notif.id), trigger_on="On test failures") + + assert "No fields changed" in out + notif.save.assert_not_called() + + +def test_update_notification_partial_no_op_diff_shows_only_changed(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + enabled=True, trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification( + notification_id=str(notif.id), + enabled=True, # no-op + trigger_on="Always", # change + ) + + # "Trigger" row present in diff, "Status" row absent. + assert "Always" in out + assert "Trigger" in out + # Status field should not appear in the diff table since it's a no-op. + assert "| Status |" not in out + assert "Status |" not in out.split("# ", 1)[1].split("\n## ")[0] or True # tolerant; main check above + notif.save.assert_called_once() + + +# --- Happy paths --- + + +def test_update_notification_test_run_recipients_and_enabled(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + enabled=True, recipients=("alice@example.com",), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification( + notification_id=str(notif.id), + recipients=["bob@example.com"], + enabled=False, + ) + + assert notif.recipients == ["bob@example.com"] + assert notif.enabled is False + notif.save.assert_called_once() + assert "# Test Run Notification updated" in out + assert "Active" in out + assert "Paused" in out + assert "alice@example.com" in out + assert "bob@example.com" in out + + +def test_update_notification_test_run_change_trigger(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification(notification_id=str(notif.id), trigger_on="Always") + + assert notif.trigger == TestRunNotificationTrigger.always + notif.save.assert_called_once() + assert "On test failures" in out + assert "Always" in out + # No internal codes leak. + assert "on_failures" not in out + + +def test_update_notification_profiling_run_change_trigger(db_session_mock): + notif = _update_mock(event=NotificationEvent.profiling_run, table_group_id=uuid4(), + trigger=ProfilingRunNotificationTrigger.on_changes) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification(notification_id=str(notif.id), trigger_on="Always") + + assert notif.trigger == ProfilingRunNotificationTrigger.always + notif.save.assert_called_once() + assert "On new hygiene issues" in out + assert "Always" in out + + +def test_update_notification_score_drop_change_total_threshold(db_session_mock): + notif = _update_mock(event=NotificationEvent.score_drop, score_definition_id=uuid4(), + total_score_threshold=Decimal("85.0"), + cde_score_threshold=Decimal("90.0")) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification(notification_id=str(notif.id), total_threshold=92) + + assert notif.total_score_threshold == 92 + notif.save.assert_called_once() + assert "85.0" in out + assert "92" in out + assert "# Score Drop Notification updated" in out + + +def test_update_notification_score_drop_change_cde_and_clear_total(db_session_mock): + """Current: total=85, cde=NULL. Set cde=88 AND clear total → resulting total=NULL, cde=88 (valid).""" + notif = _update_mock(event=NotificationEvent.score_drop, score_definition_id=uuid4(), + total_score_threshold=Decimal("85.0"), + cde_score_threshold=None) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification( + notification_id=str(notif.id), + cde_threshold=88, + clear_total_threshold=True, + ) + + assert notif.total_score_threshold is None + assert notif.cde_score_threshold == 88 + notif.save.assert_called_once() + assert "85.0" in out + assert "88" in out + # Cleared total renders as em-dash. + assert "—" in out + + +def test_update_notification_monitor_run_recipients(db_session_mock): + notif = _update_mock(event=NotificationEvent.monitor_run, + table_group_id=uuid4(), test_suite_id=uuid4(), + trigger=MonitorNotificationTrigger.on_anomalies, + recipients=("a@x.com",)) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification( + notification_id=str(notif.id), + recipients=["b@x.com", "c@x.com"], + ) + + assert notif.recipients == ["b@x.com", "c@x.com"] + notif.save.assert_called_once() + assert "# Monitor Alert Notification updated" in out + + +# --- Rendering --- + + +def test_update_notification_heading_event_specific(db_session_mock): + notif = _update_mock(event=NotificationEvent.profiling_run, table_group_id=uuid4(), + trigger=ProfilingRunNotificationTrigger.on_changes) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification(notification_id=str(notif.id), enabled=False) + + assert "# Profiling Run Notification updated" in out + + +def test_update_notification_notification_id_code_formatted(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification(notification_id=str(notif.id), enabled=False) + + assert f"`{notif.id}`" in out + + +def test_update_notification_status_diff_active_paused(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + enabled=True, trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification(notification_id=str(notif.id), enabled=False) + + assert "Active" in out + assert "Paused" in out + # Status row should NOT render the bool repr. + assert "True" not in out + assert "False" not in out + + +def test_update_notification_recipients_diff_comma_separated(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + recipients=("a@x.com",), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification( + notification_id=str(notif.id), + recipients=["a@x.com", "b@x.com"], + ) + + assert "a@x.com, b@x.com" in out + # No Python list repr leakage. + assert "['a@x.com'" not in out + assert "['a@x.com', 'b@x.com']" not in out + + +def test_update_notification_trigger_diff_display_labels_only(db_session_mock): + notif = _update_mock(event=NotificationEvent.test_run, test_suite_id=uuid4(), + trigger=TestRunNotificationTrigger.on_failures) + from testgen.mcp.tools.notifications import update_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = update_notification(notification_id=str(notif.id), trigger_on="Always") + + assert "Always" in out + assert "On test failures" in out + # No internal codes in diff. + assert "on_failures" not in out + assert "TestRunNotificationTrigger" not in out + + +# --------------------------------------------------------------------------- +# delete_notification +# --------------------------------------------------------------------------- + + +def test_delete_notification_invalid_uuid(db_session_mock): + from testgen.mcp.tools.notifications import delete_notification + + with _patch_perms(), pytest.raises(MCPUserError, match="not a valid UUID"): + delete_notification(notification_id="not-a-uuid") + + +def test_delete_notification_unknown_id_returns_not_accessible(db_session_mock): + from testgen.mcp.tools.notifications import delete_notification + + with _patch_perms(), _patch_notification_get(None), pytest.raises( + MCPResourceNotAccessible, match="Notification", + ): + delete_notification(notification_id=str(uuid4())) + + +def test_delete_notification_inaccessible_project_returns_unified_not_accessible(db_session_mock): + """``NotificationSettings.get`` returns ``None`` when the project filter excludes the row. + + Both the missing-id and the wrong-project paths must surface as the same error + so callers can't enumerate notifications across projects they don't own. + """ + from testgen.mcp.tools.notifications import delete_notification + + with _patch_perms(allowed=("demo",)), _patch_notification_get(None), pytest.raises( + MCPResourceNotAccessible, match="Notification", + ): + delete_notification(notification_id=str(uuid4())) + + +def test_delete_notification_does_not_call_delete_when_inaccessible(db_session_mock): + """When resolve_notification fails, the row's .delete() is never invoked.""" + from testgen.mcp.tools.notifications import delete_notification + + sentinel = _notif_mock(event=NotificationEvent.test_run, test_suite_id=uuid4()) + with _patch_perms(), _patch_notification_get(None), pytest.raises(MCPResourceNotAccessible): + delete_notification(notification_id=str(uuid4())) + sentinel.delete.assert_not_called() + + +def test_delete_notification_calls_model_delete(db_session_mock): + notif = _notif_mock( + event=NotificationEvent.test_run, + test_suite_id=uuid4(), + settings={"trigger": "on_failures"}, + ) + from testgen.mcp.tools.notifications import delete_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups( + suite_name="orders_v1", + ): + delete_notification(notification_id=str(notif.id)) + + notif.delete.assert_called_once() + + +def test_delete_notification_test_run_renders_event_heading_and_scope(db_session_mock): + suite_id = uuid4() + notif = _notif_mock( + event=NotificationEvent.test_run, + test_suite_id=suite_id, + settings={"trigger": "on_failures"}, + ) + from testgen.mcp.tools.notifications import delete_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups( + suite_name="orders_v1", + ): + out = delete_notification(notification_id=str(notif.id)) + + assert "# Test Run Notification deleted" in out + assert f"`{notif.id}`" in out + assert "Event Type:** Test Run" in out + assert "Project:** `demo`" in out + assert "Test Suite:** orders_v1" in out + assert f"`{suite_id}`" in out + # No internal code leakage. + assert "test_run" not in out + + +def test_delete_notification_profiling_run_renders_table_group_scope(db_session_mock): + tg_id = uuid4() + notif = _notif_mock( + event=NotificationEvent.profiling_run, + table_group_id=tg_id, + settings={"trigger": "on_changes"}, + ) + from testgen.mcp.tools.notifications import delete_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups( + tg_name="prod_warehouse", + ): + out = delete_notification(notification_id=str(notif.id)) + + assert "# Profiling Run Notification deleted" in out + assert "Event Type:** Profiling Run" in out + assert "Table Group:** prod_warehouse" in out + assert f"`{tg_id}`" in out + assert "profiling_run" not in out + + +def test_delete_notification_score_drop_renders_scorecard_scope(db_session_mock): + sd_id = uuid4() + notif = _notif_mock( + event=NotificationEvent.score_drop, + score_definition_id=sd_id, + settings={"total_threshold": "85.0"}, + ) + from testgen.mcp.tools.notifications import delete_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups( + score_name="Daily Orders Health", + ): + out = delete_notification(notification_id=str(notif.id)) + + assert "# Score Drop Notification deleted" in out + assert "Event Type:** Score Drop" in out + assert "Scorecard:** Daily Orders Health" in out + assert f"`{sd_id}`" in out + assert "score_drop" not in out + + +def test_delete_notification_monitor_run_renders_table_group(db_session_mock): + tg_id = uuid4() + suite_id = uuid4() + notif = _notif_mock( + event=NotificationEvent.monitor_run, + table_group_id=tg_id, + test_suite_id=suite_id, + settings={"trigger": "on_anomalies"}, + ) + from testgen.mcp.tools.notifications import delete_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups( + suite_name="monitors_v2", tg_name="prod_warehouse", + ): + out = delete_notification(notification_id=str(notif.id)) + + assert "# Monitor Alert Notification deleted" in out + assert "Event Type:** Monitor Alert" in out + assert "Table Group:** prod_warehouse" in out + assert f"`{tg_id}`" in out + assert "monitor_run" not in out + # The internal monitor test suite is never exposed. + assert "Test Suite" not in out + assert "monitors_v2" not in out + assert f"`{suite_id}`" not in out + + +def test_delete_notification_test_run_project_wide_omits_parent_id(db_session_mock): + notif = _notif_mock( + event=NotificationEvent.test_run, + test_suite_id=None, + settings={"trigger": "always"}, + ) + from testgen.mcp.tools.notifications import delete_notification + + with _patch_perms(), _patch_notification_get(notif), _patch_get_notification_scope_lookups(): + out = delete_notification(notification_id=str(notif.id)) + + assert "Test Suite:** All Test Suites" in out + # Project-wide notifications have no parent id to surface in the scope row. + assert "(`" not in out.split("Test Suite:**")[1].split("\n")[0] From e4fef2ca7d62a57a52e0d2311cd397731f0b2f20 Mon Sep 17 00:00:00 2001 From: Luis Date: Fri, 29 May 2026 18:12:08 -0400 Subject: [PATCH 46/58] refactor(mcp): remove redundant session flush in schedule tools Entity.save() already flushes the instance, so the explicit get_current_session().flush() after each sched.save() was redundant. Co-Authored-By: Claude Opus 4.8 (1M context) --- testgen/mcp/tools/schedules.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/testgen/mcp/tools/schedules.py b/testgen/mcp/tools/schedules.py index 31b534b1..9c1ed1d0 100644 --- a/testgen/mcp/tools/schedules.py +++ b/testgen/mcp/tools/schedules.py @@ -206,7 +206,6 @@ def create_profiling_schedule( active=active, ) sched.save() - get_current_session().flush() doc = MdDoc() doc.heading(1, f"Profiling schedule created for `{table_group.table_groups_name}`") @@ -241,7 +240,6 @@ def create_test_run_schedule( active=active, ) sched.save() - get_current_session().flush() doc = MdDoc() doc.heading(1, f"Test run schedule created for `{suite.test_suite}`") @@ -291,7 +289,6 @@ def update_schedule( sched.active = active sched.save() - get_current_session().flush() doc = MdDoc() doc.heading(1, "Schedule updated") From b26e147a76df2a382ddf45338842cd85e126e532 Mon Sep 17 00:00:00 2001 From: Aarthy Adityan Date: Mon, 1 Jun 2026 12:52:25 -0400 Subject: [PATCH 47/58] fix(ui): handle out-of-range dates when serializing results to JSON DataFrame.to_json forces datetimes through pandas' nanosecond Timestamp, raising OverflowError on dates outside 1677-2262 (e.g. SQL Server sentinel dates like 9999-12-31), which crashed the profiling/test results and test definitions pages. Add utils.dataframe_to_json_records() that serializes each cell via make_json_safe, harden make_json_safe to map NaT to null, and swap all to_json(date_unit=s) call sites. Includes regression tests. TG-1101 Co-Authored-By: Claude Opus 4.8 (1M context) --- testgen/ui/views/profiling_results.py | 9 ++++---- testgen/ui/views/test_definitions.py | 7 +++--- testgen/ui/views/test_results.py | 15 ++++++------- testgen/utils/__init__.py | 13 +++++++++++ tests/unit/test_utils.py | 32 +++++++++++++++++++++++++++ 5 files changed, 59 insertions(+), 17 deletions(-) diff --git a/testgen/ui/views/profiling_results.py b/testgen/ui/views/profiling_results.py index 1c5cdd7c..f779394a 100644 --- a/testgen/ui/views/profiling_results.py +++ b/testgen/ui/views/profiling_results.py @@ -28,7 +28,7 @@ from testgen.ui.services.query_cache import get_profiling_run_minimal from testgen.ui.session import session from testgen.ui.views.data_catalog import get_preview_data -from testgen.utils import make_json_safe +from testgen.utils import dataframe_to_json_records, make_json_safe PAGE_SIZE = 500 @@ -163,15 +163,14 @@ def render( pii_columns = get_pii_columns(str(run.table_groups_id)) mask_profiling_pii(df, pii_columns) - # Use pandas JSON serialization to safely handle NaN/NaT -> null, timestamps -> epoch seconds - items = json.loads(df.to_json(orient="records", date_unit="s")) + items = dataframe_to_json_records(df) selected_item = st.session_state.get(SELECTED_ITEM_KEY) # Load selected item if URL has a selection but session cache is missing or stale if selected and (selected_item is None or selected_item.get("id") != selected): row_df = df[df["id"] == selected] if not row_df.empty: - row = json.loads(row_df.to_json(orient="records", date_unit="s"))[0] + row = dataframe_to_json_records(row_df)[0] row["hygiene_issues"] = profiling_queries.get_hygiene_issues( run_id, row["table_name"], row.get("column_name") ) @@ -189,7 +188,7 @@ def on_row_selected(item_id: str) -> None: row_df = df[df["id"] == item_id] if row_df.empty: return - row = json.loads(row_df.to_json(orient="records", date_unit="s"))[0] + row = dataframe_to_json_records(row_df)[0] row["hygiene_issues"] = profiling_queries.get_hygiene_issues( run_id, row["table_name"], row.get("column_name") ) diff --git a/testgen/ui/views/test_definitions.py b/testgen/ui/views/test_definitions.py index 4ba66240..f460d873 100644 --- a/testgen/ui/views/test_definitions.py +++ b/testgen/ui/views/test_definitions.py @@ -1,4 +1,3 @@ -import json import logging import typing from datetime import UTC, datetime @@ -45,7 +44,7 @@ select_test_suites_minimal_where, ) from testgen.ui.session import session -from testgen.utils import make_json_safe, to_dataframe +from testgen.utils import dataframe_to_json_records, make_json_safe, to_dataframe LOG = logging.getLogger("testgen") @@ -275,7 +274,7 @@ def on_edit_dialog_opened(payload: dict) -> None: # Fetch fresh row from the current data row_df = df[df["id"] == test_def_id] if not row_df.empty: - test_def = json.loads(row_df.to_json(orient="records", date_unit="s"))[0] + test_def = dataframe_to_json_records(row_df)[0] st.session_state[TD_EDIT_DIALOG_KEY] = test_def def on_delete_dialog_opened(selected: list) -> None: @@ -616,7 +615,7 @@ def on_sort_changed(payload: dict) -> None: "test_suite": test_suite.test_suite, "project_code": project_code, }, - "test_definitions": json.loads(df.to_json(orient="records", date_unit="s")), + "test_definitions": dataframe_to_json_records(df), "filter_options": { "tables": table_options, "columns": columns_raw, diff --git a/testgen/ui/views/test_results.py b/testgen/ui/views/test_results.py index 9c26ea21..16a1ed18 100644 --- a/testgen/ui/views/test_results.py +++ b/testgen/ui/views/test_results.py @@ -1,4 +1,3 @@ -import json import typing from io import BytesIO from itertools import zip_longest @@ -43,7 +42,7 @@ ) from testgen.ui.services.string_service import snake_case_to_title_case from testgen.ui.session import session -from testgen.utils import friendly_score, make_json_safe +from testgen.utils import dataframe_to_json_records, friendly_score, make_json_safe PAGE_PATH = "test-runs:results" PAGE_SIZE = 500 @@ -224,7 +223,7 @@ def render( test_suite = get_test_suite_minimal(run.test_suite_id) - items = json.loads(df.to_json(orient="records", date_unit="s")) + items = dataframe_to_json_records(df) summary = get_test_result_summary(run_id) score = friendly_score(run.dq_score_test_run) or "--" @@ -233,7 +232,7 @@ def render( if selected and (selected_item is None or selected_item.get("test_result_id") != selected): row_df = df[df["test_result_id"] == selected] if not row_df.empty: - row = json.loads(row_df.to_json(orient="records", date_unit="s"))[0] + row = dataframe_to_json_records(row_df)[0] selected_item = build_selected_item_data(row, test_suite) st.session_state[SELECTED_ITEM_KEY] = selected_item elif not selected: @@ -255,7 +254,7 @@ def on_row_selected(item_id: str) -> None: row_df = df[df["test_result_id"] == item_id] if row_df.empty: return - row = json.loads(row_df.to_json(orient="records", date_unit="s"))[0] + row = dataframe_to_json_records(row_df)[0] item_data = build_selected_item_data(row, test_suite) st.session_state[SELECTED_ITEM_KEY] = item_data Router().set_query_params({"selected": item_id}) @@ -370,7 +369,7 @@ def on_notes_dialog_closed(*_) -> None: def on_source_data_clicked(item_id: str) -> None: result_df = test_result_queries.get_test_results_by_ids([item_id]) if not result_df.empty: - row = json.loads(result_df.to_json(orient="records", date_unit="s"))[0] + row = dataframe_to_json_records(result_df)[0] MixpanelService().send_event("view-source-data", page=PAGE_PATH, test_type=row.get("test_name_short")) mask_pii = not session.auth.user_has_permission("view_pii") st.session_state[SOURCE_DATA_KEY] = _build_source_data(row, mask_pii=mask_pii) @@ -440,7 +439,7 @@ def on_issue_report_clicked(payload: dict) -> None: result_df = test_result_queries.get_test_results_by_ids(ids) if result_df.empty: return - rows = json.loads(result_df.to_json(orient="records", date_unit="s")) + rows = dataframe_to_json_records(result_df) MixpanelService().send_event("download-issue-report", page=PAGE_PATH, issue_count=len(rows)) st.session_state[ISSUE_REPORT_KEY] = rows @@ -660,7 +659,7 @@ def build_selected_item_data(row: dict, test_suite: TestSuiteMinimal) -> dict: dfh = test_result_queries.get_test_result_history(row) time_columns = ["test_date"] date_service.accommodate_dataframe_to_timezone(dfh, st.session_state, time_columns) - history = json.loads(dfh.to_json(orient="records", date_unit="s")) + history = dataframe_to_json_records(dfh) test_definition = _build_test_definition_data(row.get("test_definition_id"), test_suite) diff --git a/testgen/utils/__init__.py b/testgen/utils/__init__.py index 7f3b71d5..5218d3d0 100644 --- a/testgen/utils/__init__.py +++ b/testgen/utils/__init__.py @@ -101,6 +101,8 @@ def make_json_safe(value: Any) -> str | bool | int | float | None: elif isinstance(value, UUID): return str(value) elif isinstance(value, datetime): + if value != value: # NaT (and other nan-like datetimes) are never equal to themselves + return None return int(value.replace(tzinfo=UTC).timestamp()) elif isinstance(value, date): return value.isoformat() @@ -115,6 +117,17 @@ def make_json_safe(value: Any) -> str | bool | int | float | None: return value +def dataframe_to_json_records(df: pd.DataFrame) -> list[dict]: + """Convert a DataFrame to JSON-safe records, one dict per row. + + Routes every cell through make_json_safe rather than DataFrame.to_json. to_json forces datetime values + through pandas' nanosecond Timestamp, which raises OverflowError on dates outside 1677-09-21..2262-04-11 + (e.g. the year-9999 / year-1 sentinel dates that SQL Server date/datetime2 columns commonly carry). + make_json_safe handles native datetimes via timedelta arithmetic, so any in-range datetime is unaffected. + """ + return [{key: make_json_safe(value) for key, value in record.items()} for record in df.to_dict(orient="records")] + + def chunk_queries(queries: list[str], join_string: str, max_query_length: int) -> list[str]: full_query = join_string.join(queries) if len(full_query) <= max_query_length: diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index aea93451..eb444028 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -4,10 +4,12 @@ from enum import Enum from uuid import UUID +import pandas as pd import pytest from testgen.utils import ( chunk_queries, + dataframe_to_json_records, friendly_score, friendly_score_impact, get_exception_message, @@ -124,6 +126,16 @@ def test_make_json_safe_datetime(): assert make_json_safe(dt) == int(dt.timestamp()) +def test_make_json_safe_nat(): + assert make_json_safe(pd.NaT) is None + + +@pytest.mark.parametrize("dt", [datetime(1, 1, 1), datetime(9999, 12, 31)]) +def test_make_json_safe_out_of_nanosecond_range_datetime(dt): + # Datetimes outside pandas' nanosecond Timestamp range (1677..2262) must still serialize. + assert make_json_safe(dt) == int(dt.replace(tzinfo=UTC).timestamp()) + + def test_make_json_safe_decimal(): assert make_json_safe(Decimal("3.14")) == 3.14 @@ -152,6 +164,26 @@ def test_make_json_safe_passthrough(): assert make_json_safe(None) is None +# --- dataframe_to_json_records --- + +def test_dataframe_to_json_records_empty(): + assert dataframe_to_json_records(pd.DataFrame()) == [] + + +def test_dataframe_to_json_records_handles_out_of_range_dates_and_nulls(): + # Rows mixing out-of-nanosecond-range datetimes with NaT/NaN must serialize without overflow. + df = pd.DataFrame([ + {"id": "1", "min_date": datetime(1, 1, 1), "max_date": datetime(9999, 12, 31), "frac": 1.5}, + {"id": "2", "min_date": datetime(2020, 6, 1), "max_date": pd.NaT, "frac": None}, + ]) + records = dataframe_to_json_records(df) + + assert records[0]["min_date"] == int(datetime(1, 1, 1, tzinfo=UTC).timestamp()) + assert records[0]["max_date"] == int(datetime(9999, 12, 31, tzinfo=UTC).timestamp()) + assert records[1]["max_date"] is None + assert records[1]["frac"] is None + + # --- chunk_queries --- def test_chunk_queries_fits_in_one(): From 1049e1d2a0074b46029224f383d2ca55cfcbe153 Mon Sep 17 00:00:00 2001 From: Aarthy Adityan Date: Mon, 1 Jun 2026 12:52:36 -0400 Subject: [PATCH 48/58] fix(generation): correct Freshness_Trend tran_date_cols filter precedence WHERE general_type IN (...) AND a OR b OR c parsed as (... AND a) OR b OR c, letting columns with general_type outside ('A','D','N') into the fingerprint CASE (which has no matching branch), collapsing custom_query to NULL and producing CAST( AS VARCHAR(MAX)) -> SQL Server syntax error 156. Parenthesize the OR group across all 7 flavor templates so every selected column matches a CASE branch. TG-1102 Co-Authored-By: Claude Opus 4.8 (1M context) --- .../bigquery/gen_query_tests/gen_Freshness_Trend.sql | 8 +++++--- .../databricks/gen_query_tests/gen_Freshness_Trend.sql | 8 +++++--- .../flavors/mssql/gen_query_tests/gen_Freshness_Trend.sql | 8 +++++--- .../oracle/gen_query_tests/gen_Freshness_Trend.sql | 8 +++++--- .../gen_query_tests/gen_Freshness_Trend.sql | 8 +++++--- .../sap_hana/gen_query_tests/gen_Freshness_Trend.sql | 8 +++++--- testgen/template/gen_query_tests/gen_Freshness_Trend.sql | 8 +++++--- 7 files changed, 35 insertions(+), 21 deletions(-) diff --git a/testgen/template/flavors/bigquery/gen_query_tests/gen_Freshness_Trend.sql b/testgen/template/flavors/bigquery/gen_query_tests/gen_Freshness_Trend.sql index ed6c227c..84944de7 100644 --- a/testgen/template/flavors/bigquery/gen_query_tests/gen_Freshness_Trend.sql +++ b/testgen/template/flavors/bigquery/gen_query_tests/gen_Freshness_Trend.sql @@ -75,9 +75,11 @@ tran_date_cols AS ( ) AS rank FROM latest_results WHERE general_type IN ('A', 'D', 'N') - AND functional_data_type ILIKE 'transactional date%' - OR functional_data_type ILIKE 'period%' - OR functional_data_type = 'timestamp' + AND ( + functional_data_type ILIKE 'transactional date%' + OR functional_data_type ILIKE 'period%' + OR functional_data_type = 'timestamp' + ) ), -- Numeric Measures numeric_cols AS ( diff --git a/testgen/template/flavors/databricks/gen_query_tests/gen_Freshness_Trend.sql b/testgen/template/flavors/databricks/gen_query_tests/gen_Freshness_Trend.sql index aa9d2a87..a057fa6f 100644 --- a/testgen/template/flavors/databricks/gen_query_tests/gen_Freshness_Trend.sql +++ b/testgen/template/flavors/databricks/gen_query_tests/gen_Freshness_Trend.sql @@ -75,9 +75,11 @@ tran_date_cols AS ( ) AS rank FROM latest_results WHERE general_type IN ('A', 'D', 'N') - AND functional_data_type ILIKE 'transactional date%' - OR functional_data_type ILIKE 'period%' - OR functional_data_type = 'timestamp' + AND ( + functional_data_type ILIKE 'transactional date%' + OR functional_data_type ILIKE 'period%' + OR functional_data_type = 'timestamp' + ) ), -- Numeric Measures numeric_cols AS ( diff --git a/testgen/template/flavors/mssql/gen_query_tests/gen_Freshness_Trend.sql b/testgen/template/flavors/mssql/gen_query_tests/gen_Freshness_Trend.sql index a14dc9a4..b8d45adb 100644 --- a/testgen/template/flavors/mssql/gen_query_tests/gen_Freshness_Trend.sql +++ b/testgen/template/flavors/mssql/gen_query_tests/gen_Freshness_Trend.sql @@ -75,9 +75,11 @@ tran_date_cols AS ( ) AS rank FROM latest_results WHERE general_type IN ('A', 'D', 'N') - AND functional_data_type ILIKE 'transactional date%' - OR functional_data_type ILIKE 'period%' - OR functional_data_type = 'timestamp' + AND ( + functional_data_type ILIKE 'transactional date%' + OR functional_data_type ILIKE 'period%' + OR functional_data_type = 'timestamp' + ) ), -- Numeric Measures numeric_cols AS ( diff --git a/testgen/template/flavors/oracle/gen_query_tests/gen_Freshness_Trend.sql b/testgen/template/flavors/oracle/gen_query_tests/gen_Freshness_Trend.sql index 05724f8f..a1e753ef 100644 --- a/testgen/template/flavors/oracle/gen_query_tests/gen_Freshness_Trend.sql +++ b/testgen/template/flavors/oracle/gen_query_tests/gen_Freshness_Trend.sql @@ -75,9 +75,11 @@ tran_date_cols AS ( ) AS rank FROM latest_results WHERE general_type IN ('A', 'D', 'N') - AND functional_data_type ILIKE 'transactional date%' - OR functional_data_type ILIKE 'period%' - OR functional_data_type = 'timestamp' + AND ( + functional_data_type ILIKE 'transactional date%' + OR functional_data_type ILIKE 'period%' + OR functional_data_type = 'timestamp' + ) ), -- Numeric Measures numeric_cols AS ( diff --git a/testgen/template/flavors/salesforce_data360/gen_query_tests/gen_Freshness_Trend.sql b/testgen/template/flavors/salesforce_data360/gen_query_tests/gen_Freshness_Trend.sql index 8a8f1ec7..6abcd6d8 100644 --- a/testgen/template/flavors/salesforce_data360/gen_query_tests/gen_Freshness_Trend.sql +++ b/testgen/template/flavors/salesforce_data360/gen_query_tests/gen_Freshness_Trend.sql @@ -75,9 +75,11 @@ tran_date_cols AS ( ) AS rank FROM latest_results WHERE general_type IN ('A', 'D', 'N') - AND functional_data_type ILIKE 'transactional date%' - OR functional_data_type ILIKE 'period%' - OR functional_data_type = 'timestamp' + AND ( + functional_data_type ILIKE 'transactional date%' + OR functional_data_type ILIKE 'period%' + OR functional_data_type = 'timestamp' + ) ), -- Numeric Measures numeric_cols AS ( diff --git a/testgen/template/flavors/sap_hana/gen_query_tests/gen_Freshness_Trend.sql b/testgen/template/flavors/sap_hana/gen_query_tests/gen_Freshness_Trend.sql index 06f09372..f9552a78 100644 --- a/testgen/template/flavors/sap_hana/gen_query_tests/gen_Freshness_Trend.sql +++ b/testgen/template/flavors/sap_hana/gen_query_tests/gen_Freshness_Trend.sql @@ -75,9 +75,11 @@ tran_date_cols AS ( ) AS rank FROM latest_results WHERE general_type IN ('A', 'D', 'N') - AND functional_data_type ILIKE 'transactional date%' - OR functional_data_type ILIKE 'period%' - OR functional_data_type = 'timestamp' + AND ( + functional_data_type ILIKE 'transactional date%' + OR functional_data_type ILIKE 'period%' + OR functional_data_type = 'timestamp' + ) ), -- Numeric Measures numeric_cols AS ( diff --git a/testgen/template/gen_query_tests/gen_Freshness_Trend.sql b/testgen/template/gen_query_tests/gen_Freshness_Trend.sql index cc83e820..a3d5f718 100644 --- a/testgen/template/gen_query_tests/gen_Freshness_Trend.sql +++ b/testgen/template/gen_query_tests/gen_Freshness_Trend.sql @@ -75,9 +75,11 @@ tran_date_cols AS ( ) AS rank FROM latest_results WHERE general_type IN ('A', 'D', 'N') - AND functional_data_type ILIKE 'transactional date%' - OR functional_data_type ILIKE 'period%' - OR functional_data_type = 'timestamp' + AND ( + functional_data_type ILIKE 'transactional date%' + OR functional_data_type ILIKE 'period%' + OR functional_data_type = 'timestamp' + ) ), -- Numeric Measures numeric_cols AS ( From a2ece9190b18dea388a4ec97c8f97e25c3cc891a Mon Sep 17 00:00:00 2001 From: Aarthy Adityan Date: Mon, 1 Jun 2026 14:57:26 -0400 Subject: [PATCH 49/58] fix(profiling): guard empty SPLIT_PART casts in pattern anomaly criteria The pattern-anomaly criteria cast SPLIT_PART(top_patterns, '|', N)::NUMERIC. When a SPLIT_PART yields '' (single-pattern top_patterns), ''::NUMERIC raises 'invalid input syntax for type numeric'. Postgres gives no WHERE short-circuit guarantee, so the cast is evaluated on rows the other predicates exclude -- crashing the anomaly screen intermittently. Wrap each SPLIT_PART(top_patterns,...) cast in NULLIF(..., '') across Column_Pattern_Mismatch, Table_Pattern_Mismatch, and Invalid_Zip3_USA. Static metadata; refreshed on upgrade. TG-1103 Co-Authored-By: Claude --- .../profile_anomaly_types_Column_Pattern_Mismatch.yaml | 6 +++--- .../profile_anomaly_types_Invalid_Zip3_USA.yaml | 4 ++-- .../profile_anomaly_types_Table_Pattern_Mismatch.yaml | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Column_Pattern_Mismatch.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Column_Pattern_Mismatch.yaml index 2e228e69..f6947302 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Column_Pattern_Mismatch.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Column_Pattern_Mismatch.yaml @@ -15,9 +15,9 @@ profile_anomaly_types: AND STRPOS(p.top_patterns, 'N') > 0 AND ( ( (STRPOS(p.top_patterns, 'A') > 0 OR STRPOS(p.top_patterns, 'a') > 0) - AND SPLIT_PART(p.top_patterns, '|', 3)::NUMERIC / SPLIT_PART(p.top_patterns, '|', 1)::NUMERIC < 0.05) + AND NULLIF(SPLIT_PART(p.top_patterns, '|', 3), '')::NUMERIC / NULLIF(SPLIT_PART(p.top_patterns, '|', 1), '')::NUMERIC < 0.05) OR - SPLIT_PART(p.top_patterns, '|', 3)::NUMERIC / SPLIT_PART(p.top_patterns, '|', 1)::NUMERIC < 0.1 + NULLIF(SPLIT_PART(p.top_patterns, '|', 3), '')::NUMERIC / NULLIF(SPLIT_PART(p.top_patterns, '|', 1), '')::NUMERIC < 0.1 ) detail_expression: |- 'Patterns: ' || p.top_patterns @@ -25,7 +25,7 @@ profile_anomaly_types: suggested_action: |- Review the values for any data that doesn't conform to the most common pattern and correct any data errors. dq_score_prevalence_formula: |- - (p.record_ct - SPLIT_PART(p.top_patterns, '|', 1)::BIGINT)::FLOAT/NULLIF(p.record_ct, 0)::FLOAT + (p.record_ct - NULLIF(SPLIT_PART(p.top_patterns, '|', 1), '')::BIGINT)::FLOAT/NULLIF(p.record_ct, 0)::FLOAT dq_score_risk_factor: '0.66' dq_dimension: Validity impact_dimension: Usability diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Invalid_Zip3_USA.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Invalid_Zip3_USA.yaml index b3c9a750..723b3862 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Invalid_Zip3_USA.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Invalid_Zip3_USA.yaml @@ -9,14 +9,14 @@ profile_anomaly_types: p.distinct_pattern_ct > 1 AND (p.column_name ilike '%zip%' OR p.column_name ILIKE '%postal%') AND SPLIT_PART(p.top_patterns, ' | ', 2) = 'NNN' - AND SPLIT_PART(p.top_patterns, ' | ', 1)::FLOAT/NULLIF(value_ct, 0)::FLOAT > 0.50 + AND NULLIF(SPLIT_PART(p.top_patterns, ' | ', 1), '')::FLOAT/NULLIF(value_ct, 0)::FLOAT > 0.50 detail_expression: |- 'Pattern: ' || p.top_patterns issue_likelihood: Definite suggested_action: |- Review your source data, ingestion process, and any processing steps that update this column. dq_score_prevalence_formula: |- - (NULLIF(p.record_ct, 0)::INT - SPLIT_PART(p.top_patterns, ' | ', 1)::BIGINT)::FLOAT/NULLIF(p.record_ct, 0)::FLOAT + (NULLIF(p.record_ct, 0)::INT - NULLIF(SPLIT_PART(p.top_patterns, ' | ', 1), '')::BIGINT)::FLOAT/NULLIF(p.record_ct, 0)::FLOAT dq_score_risk_factor: '1' dq_dimension: Validity impact_dimension: Conformance diff --git a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Table_Pattern_Mismatch.yaml b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Table_Pattern_Mismatch.yaml index 7be19eb1..319c7e3d 100644 --- a/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Table_Pattern_Mismatch.yaml +++ b/testgen/template/dbsetup_anomaly_types/profile_anomaly_types_Table_Pattern_Mismatch.yaml @@ -13,7 +13,7 @@ profile_anomaly_types: AND m.max_pattern_ct = 1 AND m.column_ct > 1 AND SPLIT_PART(p.top_patterns, '|', 2) <> SPLIT_PART(m.very_top_pattern, '|', 2) - AND SPLIT_PART(p.top_patterns, '|', 1)::NUMERIC / SPLIT_PART(m.very_top_pattern, '|', 1)::NUMERIC < 0.1 + AND NULLIF(SPLIT_PART(p.top_patterns, '|', 1), '')::NUMERIC / NULLIF(SPLIT_PART(m.very_top_pattern, '|', 1), '')::NUMERIC < 0.1 detail_expression: |- 'Patterns: ' || SPLIT_PART(p.top_patterns, '|', 2) || ', ' || SPLIT_PART(ltrim(m.very_top_pattern, '0'), '|', 2) issue_likelihood: Likely From a3173bf66d8b6de017de95b9d31b86889076a8c4 Mon Sep 17 00:00:00 2001 From: Aarthy Adityan Date: Mon, 1 Jun 2026 14:57:26 -0400 Subject: [PATCH 50/58] feat(ui): log UI render errors and show a custom error page Wrap the top-level render dispatch in app.py (the single per-rerun entry covering page, sidebar, logo, auth, config) so uncaught exceptions are logged with a full traceback and a short reference id to the testgen logger -- landing in app.log, which the in-app Application Logs dialog reads and can download. Render a friendly error message (with the reference and support email) instead of Streamlit's generic page; the sidebar stays available to navigate away. Streamlit rerun/stop signals are BaseException subclasses and pass through uncaught. TG-1104 Co-Authored-By: Claude --- testgen/ui/app.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/testgen/ui/app.py b/testgen/ui/app.py index b67cb389..5bbd67eb 100644 --- a/testgen/ui/app.py +++ b/testgen/ui/app.py @@ -1,6 +1,7 @@ import logging import os from urllib.parse import urlparse +from uuid import uuid4 import streamlit as st @@ -16,6 +17,8 @@ from testgen.ui.services.query_cache import select_projects_where from testgen.ui.session import session +LOG = logging.getLogger("testgen") + if is_standalone_mode() and (standalone_uri := os.environ.get(STANDALONE_URI_ENV_VAR)): ensure_standalone_setup(standalone_uri) @@ -84,6 +87,20 @@ def render(log_level: int = logging.INFO): ) application.router.run() + except Exception: + # Log the full traceback (tagged with a reference the user can quote) so it lands in app.log, + # which the in-app Application Logs dialog reads -- letting users download and share UI errors + # instead of needing container logs. Streamlit's rerun/stop signals are BaseException + # subclasses, so they pass through uncaught. + error_reference = uuid4().hex[:8].upper() + LOG.exception( + "Unhandled error rendering page '%s' [ref=%s]", session.current_page or "unknown", error_reference + ) + try: + _render_error_message(error_reference) + except Exception: + # Never let the error message itself break the run -- fall back to a bare message. + st.error("Something went wrong. Use the menu on the left to navigate to another page.") finally: # Safety net: commit any flushed-but-uncommitted work (e.g., PersistedSetting writes) # before RerunException propagates and bypasses database_session()'s normal commit. @@ -97,6 +114,18 @@ def render(log_level: int = logging.INFO): db_session.rollback() +def _render_error_message(reference: str) -> None: + support_email = settings.SUPPORT_EMAIL + st.error( + "**Something went wrong.**\n\n" + "An unexpected error occurred while loading this page. Use the menu on the left to navigate to " + "another page.\n\n" + "If this keeps happening, download the logs from **Help → Application Logs** and send them to " + f"[{support_email}](mailto:{support_email}) with this reference: **{reference}**.", + icon=":material/error:", + ) + + @st.cache_resource(validate=lambda _: not settings.IS_DEBUG, show_spinner=False) def get_application(log_level: int = logging.INFO): return bootstrap.run(log_level=log_level) From 2704c7a628ca958774e299482937cfb880daa174 Mon Sep 17 00:00:00 2001 From: Aarthy Adityan Date: Mon, 1 Jun 2026 18:08:18 -0400 Subject: [PATCH 51/58] fix(scorecard): improve category layout --- .../ui/components/frontend/js/pages/score_details.js | 6 +++++- testgen/ui/static/js/components/score_card.js | 10 +++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/testgen/ui/components/frontend/js/pages/score_details.js b/testgen/ui/components/frontend/js/pages/score_details.js index c8cf3ea0..94b8bc2f 100644 --- a/testgen/ui/components/frontend/js/pages/score_details.js +++ b/testgen/ui/components/frontend/js/pages/score_details.js @@ -70,7 +70,7 @@ const ScoreDetails = (/** @type {Properties} */ props) => { () => { const score = getValue(props.score); return getValue(props.permissions)?.can_edit ?? false ? div( - { class: 'flex-row tg-test-suites--card-actions' }, + { class: 'flex-row tg-score-details--card-actions' }, Button({ type: 'icon', icon: 'notifications', tooltip: 'Configure Notifications', onclick: () => emit('EditNotifications', {}) }), Button({ type: 'icon', icon: 'edit', tooltip: 'Edit Scorecard', onclick: () => emit('LinkClicked', { href: 'quality-dashboard:explorer', params: { definition_id: score.id, project_code: score.project_code } }) }), Button({ type: 'icon', icon: 'delete', tooltip: 'Delete Scorecard', onclick: () => { deleteDialogOpen.val = true; } }), @@ -171,6 +171,10 @@ stylesheet.replace(` .tg-score-details { min-height: 900px; } + +.tg-score-details--card-actions { + margin-top: -10px; +} `); export { ScoreDetails }; diff --git a/testgen/ui/static/js/components/score_card.js b/testgen/ui/static/js/components/score_card.js index 130bc470..c76bced4 100644 --- a/testgen/ui/static/js/components/score_card.js +++ b/testgen/ui/static/js/components/score_card.js @@ -90,7 +90,7 @@ const ScoreCard = (score, actions, options) => { : '', (score_.cde_score && categories.length > 0) ? i({ class: 'mr-4 ml-4' }) : '', categories.length > 0 ? div( - { class: 'flex-column' }, + { class: 'flex-column tg-score-card--breakdown' }, span({ class: 'mb-2 text-caption' }, categoriesLabel), div( { class: 'tg-score-card--categories' }, @@ -164,13 +164,17 @@ stylesheet.replace(` margin-bottom: unset !important; } +.tg-score-card--breakdown { + margin-top: -12px; +} + .tg-score-card--categories { display: flex; flex-direction: column; flex-wrap: wrap; - row-gap: 8px; + row-gap: 4px; column-gap: 16px; - max-height: 100px; + max-height: 140px; overflow-y: auto; } .tg-score-card--categories > div { From 3a09de1336d7b7983398c40056d366fc148e114f Mon Sep 17 00:00:00 2001 From: Aarthy Adityan Date: Tue, 2 Jun 2026 00:47:42 -0400 Subject: [PATCH 52/58] fix(reports): correct Column Tags and link layout in test issue report PDF The test issue report's summary table addressed its SPAN and background TableStyle commands by absolute (col, row) coordinates adapted from the hygiene report, but the test report has one extra data row (separate Measured Value and Threshold Value rows where hygiene has a single Detail row). The resulting off-by-one left the Column Tags value cell unspanned -- its values were dropped and the label pushed to the far right -- and the "View on TestGen" link in a narrow left cell instead of spanning the full row. Shift the Column Tags value span and the link row/background to their correct indices so the layout matches the hygiene report. Co-Authored-By: Claude Opus 4.8 (1M context) --- testgen/ui/pdf/test_result_report.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/testgen/ui/pdf/test_result_report.py b/testgen/ui/pdf/test_result_report.py index a7485c7c..415956aa 100644 --- a/testgen/ui/pdf/test_result_report.py +++ b/testgen/ui/pdf/test_result_report.py @@ -83,10 +83,11 @@ def build_summary_table(document, tr_data): ("SPAN", (1, 6), (2, 6)), ("SPAN", (4, 6), (5, 6)), ("SPAN", (1, 7), (5, 7)), - ("SPAN", (0, 8), (5, 8)), + ("SPAN", (1, 8), (5, 8)), + ("SPAN", (0, 9), (5, 9)), # Link cell - ("BACKGROUND", (0, 8), (5, 8), colors.white), + ("BACKGROUND", (0, 9), (5, 9), colors.white), # Measure cell ("FONT", (1, 1), (1, 1), "Helvetica-Bold"), From 753fc8ee430d7ce8efed4cb9983098081ae0722c Mon Sep 17 00:00:00 2001 From: Aarthy Adityan Date: Tue, 2 Jun 2026 00:47:47 -0400 Subject: [PATCH 53/58] fix(source-data): preserve datetimes for source-data queries and reports Hygiene and test issue dicts bound for source-data lookups and PDF reports were routed through the frontend JSON serializer (make_json_safe / dataframe_to_json_records), which encodes datetimes as epoch integers. That epoch then leaked into: - the source-data SQL, where {PROFILE_RUN_DATE} / {TEST_DATE} became e.g. CAST('1780...' AS DATE) -- a conversion error on SQL Server (and every other flavor); - the PDF report filename (epoch read by pd.Timestamp as nanoseconds -> 1970); - the PDF body dates and the result-history row highlight. Pass the un-serialized records (NaN -> None) to the query builders and PDF generators so dates arrive as real datetimes. Using to_dict instead of to_json also avoids the OverflowError on out-of-range sentinel dates (it never invokes pandas' nanosecond conversion). Frontend-bound paths keep the epoch serialization the VanJS components expect. Also normalize PROFILE_RUN_DATE to a date-only string: Oracle and SAP HANA templates use TO_DATE(..., 'YYYY-MM-DD'), which rejects a time component, and the anomaly criteria boundary is date-based (CURRENT_DATE + INTERVAL '30 year'). Co-Authored-By: Claude Opus 4.8 (1M context) --- testgen/common/source_data_service.py | 6 +++++- testgen/ui/views/hygiene_issues.py | 7 ++----- testgen/ui/views/test_results.py | 4 ++-- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/testgen/common/source_data_service.py b/testgen/common/source_data_service.py index b807f119..52cf2723 100644 --- a/testgen/common/source_data_service.py +++ b/testgen/common/source_data_service.py @@ -127,7 +127,11 @@ def build_hygiene_query(issue_data: dict, limit: int = DEFAULT_LIMIT) -> str | N "TABLE_NAME": issue_data["table_name"], "COLUMN_NAME": issue_data["column_name"], "DETAIL_EXPRESSION": issue_data["detail"], - "PROFILE_RUN_DATE": issue_data["profiling_starttime"], + # Date-only string: Oracle/HANA templates use TO_DATE(..., 'YYYY-MM-DD'), which rejects a time + # component, and the anomaly criteria boundary is date-based (CURRENT_DATE + INTERVAL '30 year'). + "PROFILE_RUN_DATE": parsed_run_date.strftime("%Y-%m-%d") + if (parsed_run_date := parse_fuzzy_date(issue_data["profiling_starttime"])) + else None, "LIMIT": limit, "LIMIT_2": int(limit / 2), "LIMIT_4": int(limit / 4), diff --git a/testgen/ui/views/hygiene_issues.py b/testgen/ui/views/hygiene_issues.py index 74bd024f..521a9de4 100644 --- a/testgen/ui/views/hygiene_issues.py +++ b/testgen/ui/views/hygiene_issues.py @@ -260,7 +260,7 @@ def on_view_source_data(row_id: str) -> None: anomaly_df = profiling_queries.get_profiling_anomalies_by_ids([row_id]) if anomaly_df.empty: return - row = make_json_safe(anomaly_df.where(anomaly_df.notna(), None).to_dict(orient="records")[0]) + row = anomaly_df.where(anomaly_df.notna(), None).to_dict(orient="records")[0] MixpanelService().send_event( "view-source-data", @@ -335,10 +335,7 @@ def on_download_report(payload: dict) -> None: anomaly_df = profiling_queries.get_profiling_anomalies_by_ids(ids) if anomaly_df.empty: return - selected_items = [ - make_json_safe(record) - for record in anomaly_df.where(anomaly_df.notna(), None).to_dict(orient="records") - ] + selected_items = anomaly_df.where(anomaly_df.notna(), None).to_dict(orient="records") MixpanelService().send_event( "download-issue-report", diff --git a/testgen/ui/views/test_results.py b/testgen/ui/views/test_results.py index 16a1ed18..fd4d7b2c 100644 --- a/testgen/ui/views/test_results.py +++ b/testgen/ui/views/test_results.py @@ -369,7 +369,7 @@ def on_notes_dialog_closed(*_) -> None: def on_source_data_clicked(item_id: str) -> None: result_df = test_result_queries.get_test_results_by_ids([item_id]) if not result_df.empty: - row = dataframe_to_json_records(result_df)[0] + row = result_df.where(result_df.notna(), None).to_dict(orient="records")[0] MixpanelService().send_event("view-source-data", page=PAGE_PATH, test_type=row.get("test_name_short")) mask_pii = not session.auth.user_has_permission("view_pii") st.session_state[SOURCE_DATA_KEY] = _build_source_data(row, mask_pii=mask_pii) @@ -439,7 +439,7 @@ def on_issue_report_clicked(payload: dict) -> None: result_df = test_result_queries.get_test_results_by_ids(ids) if result_df.empty: return - rows = dataframe_to_json_records(result_df) + rows = result_df.where(result_df.notna(), None).to_dict(orient="records") MixpanelService().send_event("download-issue-report", page=PAGE_PATH, issue_count=len(rows)) st.session_state[ISSUE_REPORT_KEY] = rows From 563dc0cca0821c1d6302eb047f609631b96bf7c1 Mon Sep 17 00:00:00 2001 From: Aarthy Adityan Date: Tue, 2 Jun 2026 02:38:24 -0400 Subject: [PATCH 54/58] docs(mcp): change doc group for test definitions --- testgen/mcp/tools/test_definitions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/testgen/mcp/tools/test_definitions.py b/testgen/mcp/tools/test_definitions.py index 856e2e54..67197f20 100644 --- a/testgen/mcp/tools/test_definitions.py +++ b/testgen/mcp/tools/test_definitions.py @@ -35,7 +35,7 @@ ) from testgen.mcp.tools.markdown import MdDoc -_DOC_GROUP = DocGroup.DISCOVER +_DOC_GROUP = DocGroup.INVESTIGATE _VALID_SCOPES = {"column", "table", "referential", "custom"} From 73b66efe06a72939ad72c073008f0cdf0ca6e3eb Mon Sep 17 00:00:00 2001 From: Aarthy Adityan Date: Tue, 2 Jun 2026 11:01:05 -0400 Subject: [PATCH 55/58] fix(source-data): handle fractional-second timestamps in parse_fuzzy_date profiling_starttime / test_date reach the source-data query builders as DB timestamp strings that include microseconds (e.g. "2026-06-02 06:54:30.105548"). parse_fuzzy_date parsed the string branch with datetime.strptime(value, "%Y-%m-%d %H:%M:%S"), which has no fractional-seconds directive and raised "unconverted data remains: .105548". This surfaced once PROFILE_RUN_DATE was routed through parse_fuzzy_date, failing the snowflake functional tests (test_main, test_sampling). Parse with datetime.fromisoformat, which handles fractional seconds and the 'T'/space separator. Add regression cases. Co-Authored-By: Claude Opus 4.8 (1M context) --- testgen/common/date_service.py | 2 +- tests/unit/common/test_date_service.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/testgen/common/date_service.py b/testgen/common/date_service.py index 72503ad3..eefaf131 100644 --- a/testgen/common/date_service.py +++ b/testgen/common/date_service.py @@ -62,7 +62,7 @@ def parse_since(since: str, *, today: date | None = None) -> date: def parse_fuzzy_date(value: str | int) -> datetime | None: if type(value) == str: - return datetime.strptime(value, "%Y-%m-%d %H:%M:%S") + return datetime.fromisoformat(value) elif type(value) == int or type(value) == float: ts = int(value) if ts >= 1e11: diff --git a/tests/unit/common/test_date_service.py b/tests/unit/common/test_date_service.py index d9f8af96..174f8a44 100644 --- a/tests/unit/common/test_date_service.py +++ b/tests/unit/common/test_date_service.py @@ -39,6 +39,16 @@ def test_parses_string_date(self): result = parse_fuzzy_date("2024-03-15 10:30:45") assert result == datetime(2024, 3, 15, 10, 30, 45) + def test_parses_string_date_with_microseconds(self): + # DB timestamp strings carry fractional seconds; the source-data lookups + # (PROFILE_RUN_DATE / TEST_DATE) feed these through parse_fuzzy_date. + result = parse_fuzzy_date("2026-06-02 06:54:30.105548") + assert result == datetime(2026, 6, 2, 6, 54, 30, 105548) + + def test_parses_iso_t_separator(self): + result = parse_fuzzy_date("2026-06-02T06:54:30") + assert result == datetime(2026, 6, 2, 6, 54, 30) + def test_parses_unix_timestamp_seconds(self): result = parse_fuzzy_date(1710500000) assert isinstance(result, datetime) From 6c95948a87d75447401e0ed69ef7c61afa8b6f44 Mon Sep 17 00:00:00 2001 From: Aarthy Adityan Date: Tue, 2 Jun 2026 12:49:27 -0400 Subject: [PATCH 56/58] fix: address review feedback --- .../frontend/standalone/project_settings/index.js | 6 +++--- testgen/ui/pdf/test_result_report.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/testgen/ui/components/frontend/standalone/project_settings/index.js b/testgen/ui/components/frontend/standalone/project_settings/index.js index 0d601c61..77622e65 100644 --- a/testgen/ui/components/frontend/standalone/project_settings/index.js +++ b/testgen/ui/components/frontend/standalone/project_settings/index.js @@ -201,16 +201,16 @@ const ProjectSettings = (props) => { content: div( { class: 'flex-column fx-gap-3' }, Checkbox({ - label: 'Enable data retention', + label: 'Automatically delete old profiling and test history', checked: form.data_retention_enabled, - help: 'Automatically delete old profiling and test run data to keep your database lean. The most recent run in each suite or table group is always preserved.', + help: 'Old profiling and test runs are permanently deleted to keep the database from growing without bound. The most recent run in each test suite and table group is always kept.', onChange: (checked) => { form.data_retention_enabled.val = checked; }, }), () => form.data_retention_enabled.val ? div( { class: 'flex-column fx-gap-3' }, Input({ - label: 'Retention period (days)', + label: 'Delete history older than (days)', value: form.data_retention_days, type: 'number', step: 1, diff --git a/testgen/ui/pdf/test_result_report.py b/testgen/ui/pdf/test_result_report.py index 415956aa..9b57ff73 100644 --- a/testgen/ui/pdf/test_result_report.py +++ b/testgen/ui/pdf/test_result_report.py @@ -59,7 +59,7 @@ def build_summary_table(document, tr_data): *[ (cmd[0], *coords, *cmd[1:]) for coords in ( - ((3, 3), (3, -3)), + ((3, 3), (3, -4)), ((0, 0), (0, -2)) ) for cmd in ( From ed8ffa7e536bcb41498826256296e3b59ba81fa7 Mon Sep 17 00:00:00 2001 From: Aarthy Adityan Date: Tue, 2 Jun 2026 16:08:39 -0400 Subject: [PATCH 57/58] release: 5.33.3 -> 5.48.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d5cec9f6..7af55506 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta" [project] name = "dataops-testgen" -version = "5.33.3" +version = "5.48.0" description = "DataKitchen's Data Quality DataOps TestGen" authors = [ { "name" = "DataKitchen, Inc.", "email" = "info@datakitchen.io" }, From e474fe778a85ff2e8ba9144dc653ee2672e58090 Mon Sep 17 00:00:00 2001 From: Aarthy Adityan Date: Tue, 2 Jun 2026 22:23:29 -0400 Subject: [PATCH 58/58] fix(project-settings): disable Save button correctly --- .../standalone/project_settings/index.js | 28 +++++++++++++++++-- testgen/ui/views/project_settings.py | 4 +++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/testgen/ui/components/frontend/standalone/project_settings/index.js b/testgen/ui/components/frontend/standalone/project_settings/index.js index 77622e65..3f291eb5 100644 --- a/testgen/ui/components/frontend/standalone/project_settings/index.js +++ b/testgen/ui/components/frontend/standalone/project_settings/index.js @@ -46,6 +46,10 @@ const ProjectSettings = (props) => { // newly-stored values and these derives recompute, letting // `showRetentionConfirmation` settle back to a clean state. const browserTz = Intl.DateTimeFormat().resolvedOptions().timeZone || 'UTC'; + const persistedName = van.derive(() => props.name.val ?? ''); + const persistedUseWeights = van.derive(() => props.use_dq_score_weights.val ?? true); + const persistedObsUrl = van.derive(() => props.observability_api_url.val ?? ''); + const persistedObsKey = van.derive(() => props.observability_api_key.val ?? ''); const persistedRetentionEnabled = van.derive(() => props.data_retention_enabled.val ?? false); const persistedRetentionDays = van.derive(() => props.data_retention_days.val ?? 180); const persistedRetentionCron = van.derive(() => props.retention_cron_expr.val ?? '0 1 * * *'); @@ -66,10 +70,30 @@ const ProjectSettings = (props) => { observability_api_url: van.state(true), data_retention_days: van.state(Number.isFinite(form.data_retention_days.rawVal)), }; + // Retention is unchanged when the enabled flag matches the persisted value and, + // while enabled, the days/cron/tz also match. When retention is off, days/cron/tz + // are hidden and the backend clears them, so they don't count as unsaved changes — + // only the enabled flag matters. + const retentionUnchanged = van.derive(() => { + if (form.data_retention_enabled.val !== persistedRetentionEnabled.val) return false; + if (!form.data_retention_enabled.val) return true; + return form.data_retention_days.val === persistedRetentionDays.val + && form.retention_cron_expr.val === persistedRetentionCron.val + && form.retention_cron_tz.val === persistedRetentionTz.val; + }); + // No unsaved changes when every field matches its persisted value. Because the + // persisted derives are reactive, this settles back to `true` after a Save once + // the props update with the stored values, disabling the button again. + const noChanges = van.derive(() => form.name.val === persistedName.val + && form.use_dq_score_weights.val === persistedUseWeights.val + && form.observability_api_url.val === persistedObsUrl.val + && form.observability_api_key.val === persistedObsKey.val + && retentionUnchanged.val); const saveDisabled = van.derive(() => !formValidity.name.val || !formValidity.observability_api_url.val || !formValidity.observability_api_key.val - || (form.data_retention_enabled.val && !formValidity.data_retention_days.val)); + || (form.data_retention_enabled.val && !formValidity.data_retention_days.val) + || noChanges.val); const testObservabilityDisabled = van.derive(() => form.observability_api_url.val.length <= 0 || form.observability_api_key.val.length <= 0); const retentionCronEditorValue = van.derive(() => { if (form.retention_cron_expr.val && form.retention_cron_tz.val && form.data_retention_enabled.val) { @@ -284,7 +308,7 @@ const ProjectSettings = (props) => { }), ), div( - { class: 'flex-row fx-justify-content-flex-end' }, + { class: 'flex-row fx-justify-content-flex-end', style: 'max-width: 700px;' }, Button({ type: 'stroked', color: 'primary', diff --git a/testgen/ui/views/project_settings.py b/testgen/ui/views/project_settings.py index df188e2a..ec93229b 100644 --- a/testgen/ui/views/project_settings.py +++ b/testgen/ui/views/project_settings.py @@ -153,6 +153,8 @@ def update_project(self, project_code: str, edited_project: dict) -> None: self.project.data_retention_enabled = retention_enabled self.project.data_retention_days = retention_days if retention_enabled else None self.project.save() + get_project.clear() + select_projects_where.clear() if retention_enabled: JobSchedule.upsert_for_retention( @@ -173,6 +175,8 @@ def update_project(self, project_code: str, edited_project: dict) -> None: ) st.toast("Scores will be recalculated in the background.") + st.toast("Project settings saved", icon=":material/task_alt:") + def test_observability_connection(self, project_code: str, edited_project: dict) -> "ObservabilityConnectionStatus": try: test_observability_exporter(