diff --git a/.gitignore b/.gitignore
index 3cc335d..a4642c7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -72,3 +72,7 @@ Thumbs.db
 # Secrets (GCP service accounts, etc.)
 secrets/
 .opencode
+.Rproj.user
+CLAUDE.local.md
+.Rhistory
+r-archive/config.yml
diff --git a/docs/CLAUDE.md b/docs/CLAUDE.md
index 984b37b..797c533 100644
--- a/docs/CLAUDE.md
+++ b/docs/CLAUDE.md
@@ -8,25 +8,33 @@ Patient pipeline is complete and deployed to production (Cloud Run).
 | Module | Purpose |
 |--------|---------|
 | `extract/patient.py` | Read Excel trackers → raw parquet (openpyxl, multi-sheet) |
+| `extract/product.py` | Read Excel trackers → raw product parquet (month sheets, stock section) |
+| `extract/wide_format.py` | Mandalay wide-format handling (column expansion 2020-21; cell splitting 2017-19) |
 | `clean/patient.py` | Type conversion, validation, transformations → cleaned parquet |
-| `clean/schema.py` | 83-column meta schema matching R output |
+| `clean/product.py` | Product cleaning pipeline (R steps 2.0-2.21) → cleaned product parquet |
+| `clean/schema.py` | 83-column patient meta schema matching R output |
+| `clean/schema_product.py` | 19-column product meta schema + helpers |
 | `clean/converters.py` | Safe type conversion with ErrorCollector |
 | `clean/validators.py` | Case-insensitive allowed-values validation |
 | `clean/transformers.py` | Explicit transformations (regimen, BP splitting, FBG) |
 | `clean/date_parser.py` | Flexible date parsing (Excel serials, DD/MM/YYYY, month-year) |
 | `tables/patient.py` | Aggregate cleaned parquets → static, monthly, annual tables |
+| `tables/product.py` | Aggregate cleaned product parquets → product_data table |
 | `tables/clinic.py` | Create clinic static table from reference_data/clinic_data.xlsx |
 | `tables/logs.py` | Aggregate error logs → logs table |
+| `tables/metadata.py` | Tracker metadata table (MD5 + per-tracker output presence flags) |
 | `pipeline/patient.py` | Orchestrate extract+clean per tracker, parallel workers |
-| `pipeline/tracker.py` | Per-tracker pipeline execution |
+| `pipeline/product.py` | Product pipeline orchestration (mirrors patient, single product_data table) |
+| `pipeline/tracker.py` | Per-tracker pipeline execution (patient + product) |
 | `pipeline/models.py` | Result dataclasses |
 | `gcp/storage.py` | GCS download/upload |
 | `gcp/bigquery.py` | BigQuery table load |
 | `gcp/drive.py` | Google Drive download (clinic_data.xlsx); file ID hardcoded in module |
 | `reference/synonyms.py` | Column name synonym mapping (YAML) |
+| `reference/products.py` | Stock_Summary product reference loader (known products, categories) |
 | `reference/provinces.py` | Allowed province validation |
 | `reference/loaders.py` | YAML loading utilities |
-| `state/` | State management module (exists, not yet wired into pipeline) |
+| `state/` | Tracker manifest + incremental filtering behind the `--incremental` flag (see [migration/MIGRATION_GUIDE.md](migration/MIGRATION_GUIDE.md)) |
 | `utils/` | Shared utilities |
 | `config.py` | Pydantic settings from `.env` / `A4D_*` env vars |
 | `logging.py` | loguru setup, `file_logger()` context manager |
@@ -36,24 +44,28 @@ Patient pipeline is complete and deployed to production (Cloud Run).
## CLI Commands ```bash -uv run a4d process-patient # Extract + clean + tables (local run) -uv run a4d create-tables # Re-create all tables (patient, logs, clinic) from existing cleaned parquets +uv run a4d process-patient # Extract + clean + tables (local run, patient) +uv run a4d process-product # Extract + clean + table (local run, product) +uv run a4d create-tables # Re-create patient/logs/clinic tables from existing cleaned parquets +uv run a4d create-product-tables # Re-create product table from existing cleaned parquets uv run a4d upload-tables # Upload tables to BigQuery uv run a4d download-trackers # Download tracker files from GCS uv run a4d upload-output # Upload output directory to GCS uv run a4d download-reference-data # Download clinic_data.xlsx from Google Drive into reference_data/ -uv run a4d run-pipeline # Full end-to-end pipeline (drive download→GCS download→process→upload) +uv run a4d run-pipeline # Full end-to-end pipeline (patient + product arms, drive/GCS/BigQuery) ``` -Key options: `--file` (single tracker), `--workers N`, `--force`, `--skip-tables`, `--skip-download`, `--skip-upload`, `--skip-drive-download`. +Key options: `--file` (single tracker), `--workers N`, `--skip-tables`, `--skip-download`, `--skip-upload`, `--skip-drive-download`, `--skip-product`, `--incremental` (skip trackers matching previous run's manifest). ## Output Directory Structure ```text output/ -├── patient_data_raw/ # Raw extracted parquets (one per tracker) -├── patient_data_cleaned/ # Cleaned parquets (one per tracker) -├── tables/ # Final tables: static.parquet, monthly.parquet, annual.parquet, logs.parquet, clinic_data_static.parquet +├── patient_data_raw/ # Raw extracted patient parquets (one per tracker) +├── patient_data_cleaned/ # Cleaned patient parquets (one per tracker) +├── product_data_raw/ # Raw extracted product parquets (one per tracker) +├── product_data_cleaned/ # Cleaned product parquets (one per tracker) +├── tables/ # Final tables: patient_data_{static,monthly,annual}.parquet, product_data.parquet, clinic_data_static.parquet, table_logs.parquet, tracker_metadata.parquet └── logs/ # Per-tracker log files (JSON) ``` @@ -68,5 +80,6 @@ output/ ## Migration Status - **Patient pipeline**: complete, validated against 174 trackers, deployed to production -- **Product pipeline**: not yet started -- **State management**: module exists but not wired into pipeline yet +- **Product pipeline**: complete, merged into `src/a4d/` (2026-04-23). +- **Tracker metadata table**: generated on every `create-tables` / `run-pipeline` run (MD5 + output-presence flags) and uploaded to BigQuery `tracker_metadata`. +- **Incremental processing**: shipped 2026-05-01 behind the `--incremental` CLI flag (opt-in) on `process-patient`, `process-product`, and `run-pipeline`. Skips trackers whose MD5 + completion state match the previous run's manifest (BigQuery → local parquet → empty fallback). Default behaviour unchanged. See `a4d.state` module + [migration/MIGRATION_GUIDE.md](migration/MIGRATION_GUIDE.md) state-management section. diff --git a/docs/migration/MIGRATION_GUIDE.md b/docs/migration/MIGRATION_GUIDE.md index 1c85465..50834aa 100644 --- a/docs/migration/MIGRATION_GUIDE.md +++ b/docs/migration/MIGRATION_GUIDE.md @@ -2,7 +2,7 @@ Reference for the A4D pipeline migration from R to Python. -**Status**: Phases 0–7 complete. Patient pipeline production-ready. Product pipeline not yet started. +**Status**: Phases 0–9 complete. Patient pipeline production-ready. 
Product pipeline merged into `src/a4d/` on 2026-04-23. **Branch**: `migration` --- @@ -87,48 +87,66 @@ upload-tables # tables/*.parquet → BigQuery ``` src/a4d/ -├── extract/patient.py # Excel → raw parquet +├── extract/ +│ ├── patient.py # Patient: Excel → raw parquet +│ ├── product.py # Product: Excel month sheets → raw parquet +│ └── wide_format.py # Mandalay wide-format handlers (column/cell) ├── clean/ -│ ├── patient.py # Main cleaning pipeline -│ ├── schema.py # 83-column meta schema +│ ├── patient.py # Patient cleaning pipeline +│ ├── product.py # Product cleaning pipeline (R steps 2.0-2.21) +│ ├── schema.py # 83-column patient schema +│ ├── schema_product.py # 19-column product schema │ ├── converters.py # Safe type conversion + ErrorCollector │ ├── validators.py # Case-insensitive allowed-values │ ├── transformers.py # Explicit transformations │ └── date_parser.py # Flexible date parsing ├── tables/ │ ├── patient.py # static/monthly/annual aggregation +│ ├── product.py # product_data aggregation +│ ├── clinic.py # clinic static table │ └── logs.py # Error log aggregation ├── pipeline/ -│ ├── patient.py # Orchestration + parallel workers -│ ├── tracker.py # Per-tracker execution +│ ├── patient.py # Patient orchestration + parallel workers +│ ├── product.py # Product orchestration (mirrors patient) +│ ├── tracker.py # Per-tracker execution (patient + product) │ └── models.py # Result dataclasses ├── gcp/ │ ├── storage.py # GCS operations +│ ├── drive.py # Google Drive (clinic_data.xlsx) │ └── bigquery.py # BigQuery load ├── reference/ │ ├── synonyms.py # Column name mapping (YAML) +│ ├── products.py # Stock_Summary product reference loader │ ├── provinces.py # Allowed province validation │ └── loaders.py # YAML loading utilities ├── state/ # State management (exists, not yet wired up) ├── config.py # Pydantic settings from A4D_* env vars ├── logging.py # loguru setup ├── errors.py # Shared error types -└── cli.py # Typer CLI (6 commands) +└── cli.py # Typer CLI (patient + product commands, run-pipeline) ``` -### State Management (Designed, Not Yet Active) +### State Management (Incremental Processing) ``` 1. Container starts (stateless, fresh) 2. Query BigQuery metadata table - SELECT file_name, file_hash FROM tracker_metadata -3. Compare with current file hashes -4. Process only: new + changed + previously failed -5. Update metadata table (append new records) + SELECT file_name, clinic_code, md5, complete FROM tracker_metadata +3. Compare with current file MD5s +4. Process only: new + changed + previously incomplete +5. Re-publish metadata table (full replace) at end of run 6. Container shuts down (state persists in BigQuery) ``` -Currently: pipeline processes all trackers found in `data_root`. Incremental logic exists in `state/` but is not wired into `pipeline/patient.py` yet. +Wired up via the `a4d.state` module ([src/a4d/state/](../../src/a4d/state/)) and exposed +through the `--incremental` CLI flag on `process-patient`, `process-product`, and +`run-pipeline`. Source precedence is BigQuery → local +`output_root/tables/tracker_metadata.parquet` → empty manifest, so local devs +without `gcloud auth` get the local-parquet fallback automatically. + +The flag is **opt-in**: default behaviour is unchanged (process every tracker +found in `data_root`). Flipping the default to incremental is a separate +decision after a soak window. 
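+
+As a rough illustration of the manifest comparison (a minimal sketch, not the
+actual `a4d.state` API; `file_md5` and `filter_unchanged` are hypothetical
+stand-ins and the real column names may differ):
+
+```python
+import hashlib
+from pathlib import Path
+
+import polars as pl
+
+
+def file_md5(path: Path) -> str:
+    """Hex MD5 of a tracker file's bytes."""
+    return hashlib.md5(path.read_bytes()).hexdigest()
+
+
+def filter_unchanged(trackers: list[Path], manifest: pl.DataFrame) -> list[Path]:
+    """Keep trackers that are new, changed, or previously incomplete.
+
+    `manifest` carries file_name / md5 / complete columns; an empty manifest
+    means every tracker is processed (the default, non-incremental behaviour).
+    """
+    if manifest.is_empty():
+        return list(trackers)
+    done = {
+        (r["file_name"], r["md5"])
+        for r in manifest.filter(pl.col("complete")).iter_rows(named=True)
+    }
+    return [p for p in trackers if (p.name, file_md5(p)) not in done]
+```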
--- @@ -226,6 +244,7 @@ with file_logger("clinic_001_patient", output_root) as log: | 5 | Pipeline integration: `pipeline/patient.py` + parallel processing | | 6 | GCP: `gcp/storage.py`, `gcp/bigquery.py`, CLI commands | | 7 | Validation: 174 trackers compared, 8 bugs fixed, production verdict | +| 9 | Product pipeline: merged WIP product modules into `src/a4d/`; `run-pipeline` runs both arms. | --- @@ -238,18 +257,13 @@ with file_logger("clinic_001_patient", output_root) as log: - Compare dashboard reports with R pipeline baseline - Fix any issues discovered during first real run -### Phase 9: Product Pipeline - -- `extract/product.py` — same pattern as patient extraction -- `clean/product.py` — same pattern as patient cleaning -- `tables/product.py` — product aggregation tables -- Validate against R product pipeline outputs - -### State Management (Incremental Processing) +### Production Scheduling -- `state/` module exists with BigQuery state design -- Wire into `pipeline/patient.py` so only changed/new trackers are processed -- Required before production scheduling (Cloud Run + Cloud Scheduler) +Cloud Run + Cloud Scheduler wiring (cron, image build, deploy manifest). The +state module shipped behind the `--incremental` CLI flag — see the +[State Management](#state-management-incremental-processing) section above. +The default behaviour remains "process every tracker"; flipping the default to +incremental is a separate decision after a soak window. --- diff --git a/docs/PRODUCT_DATA_PIPELINE_FEATURE.md b/docs/migration/PRODUCT_DATA_PIPELINE_FEATURE.md similarity index 90% rename from docs/PRODUCT_DATA_PIPELINE_FEATURE.md rename to docs/migration/PRODUCT_DATA_PIPELINE_FEATURE.md index 8a1c5ca..863af54 100644 --- a/docs/PRODUCT_DATA_PIPELINE_FEATURE.md +++ b/docs/migration/PRODUCT_DATA_PIPELINE_FEATURE.md @@ -1,5 +1,7 @@ # Feature Proposal: Product Data Processing Pipeline +> **OBSOLETE — historical reference only.** The product data pipeline described here was implemented and merged into `src/a4d/` on 2026-04-23. This document predates that work and is preserved for context on the original proposal; the "Problem Statement", "Implementation Plan", "Success Criteria", and "Questions for User" sections no longer reflect reality. For current state, see [docs/CLAUDE.md](../CLAUDE.md) Migration Status. For the R↔Python step mapping that guided the implementation, see [product_r_to_python_mapping.md](product_r_to_python_mapping.md). + ## Summary Implement a complete product data processing pipeline in the Python codebase to match the functionality of the R pipeline. This is a critical gap as the R pipeline processes both patient AND product data, while the Python pipeline currently only handles patient data. diff --git a/docs/migration/PYTHON_IMPROVEMENTS.md b/docs/migration/PYTHON_IMPROVEMENTS.md index 09e51f0..1344112 100644 --- a/docs/migration/PYTHON_IMPROVEMENTS.md +++ b/docs/migration/PYTHON_IMPROVEMENTS.md @@ -122,6 +122,73 @@ for h1, h2 in zip(header_1, header_2, strict=True): **Impact**: Negligible - differences are below any meaningful precision threshold for BMI measurements. +## 5. Product Pipeline: Date Parsing Robustness + +**Status**: ✅ Improved in Python + +Three distinct date-parsing patterns surfaced during product-pipeline diff investigation against R goldens. In each case Python yields a more correct result than R; no R-parity fix is warranted. Investigation: [Ali_internship/residual_dig.ipynb](../../Ali_internship/residual_dig.ipynb). 
+ +### 5.1 "Sept" → "Sep" month abbreviation + +**Issue in R**: `lubridate` does not recognize the 4-letter abbreviation "Sept" as September. Source strings like `"05-Sept-2025"` or `"25-Sept-2025"` are rejected → `null`. + +**Python Fix**: `parse_date_flexible` strips the trailing letter from any 4-letter month abbreviation before matching: + +```python +re.sub(r"([a-zA-Z]{3})[a-zA-Z]", r"\1", date_str) +``` + +This converts `"Sept"` → `"Sep"`, after which the standard `%d-%b-%Y` parse succeeds. + +**Impact**: 46 rows across 28 product `(file, sheet, product)` groups where R has `null` and Python has a valid date — the `py_set_r_null` class in the joined-view diff. Affected trackers include `2024_Putrajaya Hospital A4D Tracker` (Sep24) and `2025_NPH A4D Tracker` (Sep25). + +**File**: [src/a4d/clean/date_parser.py:67-69](../../src/a4d/clean/date_parser.py#L67-L69) + +### 5.2 D/M/YYYY single-digit-month dates + +**Issue in R**: `lubridate::dmy()` with default formats does not parse `"20/5/2025"` (single-digit month, slash separator) — rejects → `null`. + +**Python Fix**: `parse_date_flexible` includes a slash-format match path that accepts both `D/M/YYYY` and `DD/MM/YYYY`. + +**Impact**: 5 rows in `2025_Putrajaya Hospital A4D Tracker` / May25 / WIZ Test Strips with cell values like `"20/5/2025"`, `"22/5/2025"`, `"28/5/2025"`. R rejects all; Python parses them correctly. + +**File**: [src/a4d/clean/date_parser.py](../../src/a4d/clean/date_parser.py) + +### 5.3 Future-year sentinel guard for malformed date strings + +**Issue in R**: `lubridate` is permissive — strings with structural typos like `"10-Oct-2-24"` (extra `-2-` injected) get force-parsed into plausible-but-incorrect dates (e.g. `2024-02-10`). The wrong date sorts at a different position in the cumulative-balance sequence, so intermediate `product_balance` values diverge from what the source spreadsheet shows. R has no future-year guard. + +**Python Fix**: `parse_date_flexible` returns successfully or returns the sentinel `9999-09-09` for unparseable input. `_validate_entry_dates` then re-sentinels any successfully-parsed date whose year is between the tracker year and the Buddhist-era threshold (2400) — this catches fat-fingered Gregorian years (e.g. `"2099-..."` typed in a 2024 tracker) while exempting genuine Buddhist-era dates (BE 25xx → CE 20xx). Sentinelled rows sort to the end of the cumulative-balance sequence so they don't corrupt intermediate values. + +```python +BUDDHIST_ERA_THRESHOLD = 2400 +invalid_mask = ( + pl.col("product_entry_date").is_not_null() + & (pl.col("product_entry_date") > max_valid) + & (pl.col("product_entry_date").dt.year() < BUDDHIST_ERA_THRESHOLD) +) +``` + +**Impact**: 31 rows across 6 product groups (the post-Buddhist-fix `real_divergence` class in `product_balance`) — R force-parses or silently rejects malformed strings while Python correctly sentinels them. Confirmed by raw-Excel inspection of: + +- `"10-Oct-2-24"` (×2) — Putrajaya 2024 / Oct24 / WIZ Alcohol Swabs. R parses as `2024-02-10`; Python sentinels. +- `"15-Seep-225"`, `"08-Sep02025"` — NPH 2025 / Sep25. Both R and Python sentinel/null these; balance still diverges via the date-sort interaction with the legitimate Sept rows from §5.1. +- `"01-Ju-2025"` — Sarawak 2025 / Jul25. Both pipelines reject ("Ju" is too short to disambiguate Jun/Jul); balance diverges because R returns `null` (sorts in nulls-first/last position differing from sentinel) while Python returns `9999-09-09` (sorts to end deterministically). 
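+
+A toy illustration of why the sentinel is safe for intermediate balances
+(made-up values; the column names are simplified for the example, not the
+pipeline's actual schema):
+
+```python
+from datetime import date
+
+import polars as pl
+
+# One malformed entry has been sentinelled to 9999-09-09. After sorting it lands
+# last, so every intermediate running balance matches the source spreadsheet and
+# only the final row absorbs the unusable entry.
+df = pl.DataFrame({
+    "product_entry_date": [date(2024, 10, 1), date(9999, 9, 9), date(2024, 10, 15)],
+    "units_delta": [100.0, -10.0, -30.0],
+})
+print(
+    df.sort("product_entry_date").with_columns(
+        pl.col("units_delta").cum_sum().alias("product_balance")
+    )
+)
+# product_balance: 100.0, 70.0, 60.0; the sentinel row is processed last
+```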
+ +**File**: [src/a4d/clean/product.py:_validate_entry_dates](../../src/a4d/clean/product.py) + +## 6. Product Pipeline: Running Balance FP Precision + +**Status**: ℹ️ Negligible difference + +**Observation**: Python's vectorized `cum_sum().over([sheet, product])` and R's iterative `for (i in 1:nrow)` loop produce running balances that drift by ~5.7 × 10⁻¹⁴ at the deepest accumulation step. + +**Cause**: IEEE-754 floating-point accumulation order differs between Polars' vectorized cumsum and R's row-by-row addition. + +**Impact**: 411 rows across 80 product groups in the joined-view diff are flagged as different but classify as `fp_precision_only`. Both pipelines produce identical final-row balances per group; only sub-display-precision intermediates differ. + +**File**: [src/a4d/clean/product.py:_compute_running_balance](../../src/a4d/clean/product.py) + ## Summary | Issue | R Behavior | Python Behavior | Classification | @@ -130,17 +197,23 @@ for h1, h2 in zip(header_1, header_2, strict=True): | insulin_subtype typo | "rapic-acting" (typo) | "rapid-acting" (correct spelling) | **Python Fix** | | insulin_total_units extraction | Not extracted (header merge fails for 2024+ trackers) | Correctly extracted (unconditional header merge) | **Python Fix** | | BMI precision | 16 decimal places | 14-15 decimal places | **Negligible** | +| product entry date "Sept" | Rejects 4-letter month abbreviation → null | Truncates to "Sep", parses correctly | **Python Fix** | +| product entry date `D/M/YYYY` | Rejects single-digit-month slash format → null | Parses correctly | **Python Fix** | +| product entry date — malformed strings | Force-parses to plausible-but-incorrect dates (e.g. `"10-Oct-2-24"` → 2024-02-10), distorting cumulative balance | Sentinels to `9999-09-09`, sorts to end, preserves correct intermediate balances | **Python Fix** | +| product running balance | Iterative cumsum, slightly different IEEE-754 accumulation order | Vectorized cumsum, identical final-row balances | **Negligible** | ## Migration Validation Status -✅ **Schema**: 100% match (83 columns, all types correct) +✅ **Schema**: 100% match (83 patient columns + 19 product columns) ✅ **Extraction**: Improved (unconditional header merge fixes insulin_total_units) -✅ **Cleaning**: Improved (fixes insulin_type derivation bug, corrects insulin_subtype typo) -ℹ️ **Precision**: Acceptable float differences (~10^-15 for BMI) +✅ **Cleaning**: Improved (insulin_type, insulin_subtype, product date parsing) +ℹ️ **Precision**: Acceptable float differences (~10⁻¹⁵ BMI, ~10⁻¹⁴ product running balance) -**All 3 value differences are Python improvements over R bugs.** +**All value differences are Python improvements over R bugs or negligible precision drift.** The Python pipeline is production-ready with significant improvements over the R pipeline: + 1. **More robust header parsing** - No conditional merge that fails on 2024+ trackers 2. **Better null handling** - Correctly checks all insulin columns before derivation 3. **Correct terminology** - Uses proper medical terms ("rapid-acting" not "rapic-acting") +4. 
**More robust date parsing** - Accepts "Sept" and `D/M/YYYY`; sentinels malformed strings instead of force-parsing them into plausible-but-wrong dates that distort cumulative balances diff --git a/justfile b/justfile index 8c9d005..0e41a55 100644 --- a/justfile +++ b/justfile @@ -7,9 +7,7 @@ default: PROJECT := "a4dphase2" DATASET := "tracker" REGISTRY := "asia-southeast2-docker.pkg.dev/a4dphase2/a4d/pipeline" -GIT_SHA := `git rev-parse --short HEAD` IMAGE := REGISTRY + ":latest" -IMAGE_SHA := REGISTRY + ":" + GIT_SHA # ── Environment ─────────────────────────────────────────────────────────────── @@ -25,18 +23,13 @@ update: info: @echo "Python version:" @uv run python --version - @echo "\nInstalled packages:" + @echo "" + @echo "Installed packages:" @uv pip list # Clean cache and build artifacts clean: - rm -rf .ruff_cache - rm -rf .pytest_cache - rm -rf htmlcov - rm -rf .coverage - rm -rf dist - rm -rf build - rm -rf src/*.egg-info + rm -rf .ruff_cache .pytest_cache htmlcov .coverage dist build src/*.egg-info find . -type d -name __pycache__ -exec rm -rf {} + find . -type f -name "*.pyc" -delete @@ -93,19 +86,30 @@ hooks-run: # ── Local Pipeline ──────────────────────────────────────────────────────────── -# Process a single tracker file (no GCS) +# Process a single patient tracker file (no GCS) run-file FILE: uv run a4d process-patient --file "{{FILE}}" -# Process local files only, no GCS (use files already in data_root) -# Optionally pass a path: just run-local --data-root /path/to/trackers +# Process a single product tracker file (no GCS); ad-hoc/debug only — use `just run` for full runs +run-file-product FILE: + uv run a4d process-product --file "{{FILE}}" + +# Process local patient files only, no GCS (paths with spaces: use --file recipes instead) run-local *ARGS: uv run a4d process-patient {{ARGS}} -# Create tables from existing cleaned parquet files +# Process local product files only, no GCS; ad-hoc/debug only — use `just run` for full runs +run-local-product *ARGS: + uv run a4d process-product {{ARGS}} + +# Create patient tables from existing cleaned parquet files create-tables INPUT: uv run a4d create-tables --input "{{INPUT}}" +# Create product tables from existing cleaned parquet files +create-product-tables INPUT: + uv run a4d create-product-tables --input "{{INPUT}}" + # Download from GCS, process locally, no upload run-download *ARGS: uv run a4d run-pipeline --skip-upload {{ARGS}} @@ -120,9 +124,12 @@ run *ARGS: # shows one image entry instead of three (image + attestation + index) # Build Docker image tagged as :latest and : docker-build: + #!/usr/bin/env bash + set -euo pipefail + GIT_SHA=$(git rev-parse --short HEAD) docker build --provenance=false --platform=linux/amd64 \ -t {{IMAGE}} \ - -t {{IMAGE_SHA}} \ + -t {{REGISTRY}}:${GIT_SHA} \ -f Dockerfile . 
# Smoke test: verify the image starts and the CLI is reachable @@ -131,9 +138,12 @@ docker-smoke: # Push both :latest and : tags to Artifact Registry docker-push: docker-build + #!/usr/bin/env bash + set -euo pipefail + GIT_SHA=$(git rev-parse --short HEAD) docker push {{IMAGE}} - docker push {{IMAGE_SHA}} - @echo "Pushed: {{IMAGE}} and {{IMAGE_SHA}}" + docker push {{REGISTRY}}:${GIT_SHA} + echo "Pushed: {{IMAGE}} and {{REGISTRY}}:${GIT_SHA}" # Delete all images from Artifact Registry except :latest docker-clean: @@ -169,7 +179,9 @@ backup-bq: set -euo pipefail DATE=$(date +%Y%m%d) EXPIRY="TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 7 DAY)" - TABLES="patient_data_static patient_data_monthly patient_data_annual" + # Output data tables that get WRITE_TRUNCATE'd by load_pipeline_tables on every run. + # Keep in sync with PARQUET_TO_TABLE in src/a4d/gcp/bigquery.py when adding new pipelines. + TABLES="patient_data_static patient_data_monthly patient_data_annual product_data" for TABLE in $TABLES; do if bq show --quiet {{PROJECT}}:{{DATASET}}.${TABLE} 2>/dev/null; then SNAP="${TABLE}_${DATE}" diff --git a/readme.md b/readme.md index 3614b12..6f6fcdc 100644 --- a/readme.md +++ b/readme.md @@ -6,11 +6,12 @@ Python implementation of the A4D medical tracker data processing pipeline. 🚧 **Active Development** - Migrating from R to Python -See [Migration Documentation](../MIGRATION_OVERVIEW.md) for details. +See the [Migration Guide](docs/migration/MIGRATION_GUIDE.md) for details. ## Features -- ✅ **Incremental Processing** - Only process changed tracker files +- ✅ **Dual Pipeline Arms** - Patient and product trackers processed in one run +- ✅ **Incremental Processing** - Skip trackers whose MD5 + completion state match the previous run's manifest - ✅ **Parallel Execution** - Process multiple trackers concurrently - ✅ **Stateless GCP Deployment** - Uses BigQuery for state management - ✅ **Comprehensive Error Tracking** - Detailed error logs per patient/tracker @@ -49,42 +50,74 @@ A4D_UPLOAD_BUCKET=a4dphase2_output ### Running the Pipeline ```bash -# Full pipeline +# Full end-to-end pipeline (Drive + GCS download → patient + product → GCS + BigQuery upload) just run -# or: uv run python scripts/run_pipeline.py - -# With options -just run --max-workers 8 -just run --force # Reprocess all files -just run --skip-upload # Local testing +# or: uv run a4d run-pipeline + +# Common run-pipeline flags +just run --workers 8 +just run --skip-upload # Local testing: process but don't upload +just run --skip-download # Reuse files already in data_root +just run --skip-drive-download # Skip clinic_data.xlsx refresh from Google Drive +just run --skip-product # Patient-only run +just run --skip-patient # Product-only run (mutually exclusive with --skip-product) +just run --incremental # Skip unchanged trackers (MD5 + completion match) +just run --force # Wipe prior local outputs before each arm runs + +# Single-arm runs (no GCS download/upload) +just run-local # Patient extract + clean + tables +just run-local-product # Product extract + clean + tables +just run-file path/to/tracker.xlsx # Single patient tracker +just run-file-product path/to/tracker.xlsx # Single product tracker ``` ## Architecture ``` -Pipeline Flow: -1. Query BigQuery metadata → determine changed files -2. Process changed trackers in parallel (extract → clean → validate) -3. Aggregate individual parquets → final tables -4. Upload to BigQuery -5. Update metadata table +run-pipeline flow: +0. 
Download reference data (clinic_data.xlsx) from Google Drive +1. Download tracker files from GCS +2-3. Patient arm: extract → clean → tables (static, monthly, annual) +3b. Clinic static table from clinic_data.xlsx +3c. Product arm: extract → clean → product_data table +3d. Tracker metadata table (MD5 + per-tracker output presence) +3e. Product ↔ patient link validation (logging-only) +4. Upload tables/ and logs/ to GCS under YYYY/MM/DD/HHMMSS/ +5. Ingest tables into BigQuery (WRITE_TRUNCATE) ``` +Both arms share the same tracker queue. With `--incremental`, the queue is filtered once against the previous run's manifest so both arms see the same set. + +Output tables loaded into BigQuery: + +- `patient_data_static`, `patient_data_monthly`, `patient_data_annual` +- `product_data` +- `clinic_data_static` +- `tracker_metadata` +- `logs` + ## Project Structure ``` -a4d-python/ +a4d/ ├── src/a4d/ # Main package -│ ├── config.py # Pydantic settings +│ ├── cli.py # Typer CLI (process-patient, process-product, run-pipeline, …) +│ ├── config.py # Pydantic settings (A4D_* env vars) │ ├── logging.py # loguru configuration -│ ├── extract/ # Data extraction (Script 1) -│ ├── clean/ # Data cleaning (Script 2) -│ ├── tables/ # Table creation (Script 3) -│ ├── gcp/ # BigQuery & GCS integration -│ ├── state/ # State management -│ └── utils/ # Utilities +│ ├── extract/ # Sheet extraction (patient.py, product.py, wide_format.py) +│ ├── clean/ # Cleaning + schemas (patient.py, product.py, schema*.py) +│ ├── pipeline/ # Per-arm orchestration (patient.py, product.py, tracker.py) +│ ├── tables/ # Aggregation into final parquets (patient, product, clinic, metadata, logs) +│ ├── validate/ # Source-vs-output reconciliation (patient + product) +│ ├── gcp/ # BigQuery, GCS, Drive integration +│ ├── reference/ # Reference data loaders (synonyms, validation rules, provinces) +│ ├── state/ # Manifest + incremental filtering +│ └── utils/ # Shared utilities +├── reference_data/ # Shared YAML configs + clinic_data.xlsx ├── tests/ # Test suite -├── scripts/ # CLI scripts +├── scripts/ # Utility scripts +├── docs/ # Migration + feature docs +├── justfile # Development commands └── pyproject.toml # Dependencies ``` @@ -165,11 +198,14 @@ just hooks-run ### Docker ```bash -# Build Docker image +# Build Docker image (tagged :latest and :) just docker-build -# Run container locally -just docker-run +# Smoke-test the image (CLI reachable inside the container) +just docker-smoke + +# Push both tags to Artifact Registry +just docker-push # Or manually: docker build -t a4d-python:latest . @@ -213,12 +249,13 @@ just info This project is a complete rewrite of the R pipeline with: - 2-5x performance improvement +- Patient + product trackers in a single orchestrated run - Incremental processing (only changed files) - Better error tracking and logging - Simpler deployment (single Docker container) - Modern Python best practices -See migration documentation in parent directory for details. +See [docs/migration/](docs/migration/) for the migration guide and per-feature notes. ## License diff --git a/reference_data/validation_rules.yaml b/reference_data/validation_rules.yaml index 5fbb423..a60e011 100644 --- a/reference_data/validation_rules.yaml +++ b/reference_data/validation_rules.yaml @@ -12,6 +12,20 @@ # Note: Data transformations are hardcoded in src/a4d/clean/transformers.py, # not defined in YAML. +# Numeric range bounds consumed by src/a4d/validate/source_vs_output_patient.py +# only. 
Mirrors the hardcoded thresholds in src/a4d/clean/patient.py:510-562. +# Single-source-of-truth refactor blocked by patient.py being a golden file +# (see CLAUDE.local.md). If you change one, change the other. +# height bounds apply AFTER the cm->m auto-conversion at clean/patient.py:530-535. +numeric_ranges: + height: {min: 0.0, max: 2.3} + weight: {min: 0.0, max: 200.0} + bmi: {min: 10.0, max: 80.0} + age: {min: 0, max: 100} + hba1c_baseline: {min: 0.0, max: 25.0} + hba1c_updated: {min: 0.0, max: 25.0} + fbg_updated_mmol: {min: 0.0, max: 150.0} + analog_insulin_long_acting: allowed_values: ["N", "Y"] replace_invalid: true @@ -74,17 +88,20 @@ insulin_type: replace_invalid: true insulin_subtype: - # Note: R derives "rapic-acting" (typo) but validates against "Rapid-acting" (correct) - # This causes ALL derived values to become "Undefined" because: - # 1. Single values like "rapic-acting" don't match "Rapid-acting" - # 2. Comma-separated values like "rapic-acting,long-acting" don't match any single allowed value + # Python's builder emits comma-joined values (e.g. "pre-mixed,rapid-acting") for + # multi-insulin patients, matching R's builder. R's validator treated the CSV as a + # single opaque string and rejected it, so R output "Undefined" for every multi-insulin + # patient. Python uses allow_csv_subset=true: each comma-separated token is validated + # against allowed_values independently, then rejoined in canonical case. + # R's "rapic-acting" typo is already corrected in Python's builder (clean/patient.py). allowed_values: - "Pre-mixed" - "Short-acting" - "Intermediate-acting" - - "Rapid-acting" # R expects this, but derives "rapic-acting" (typo) + - "Rapid-acting" - "Long-acting" replace_invalid: true + allow_csv_subset: true observations_category: allowed_values: diff --git a/src/a4d/clean/converters.py b/src/a4d/clean/converters.py index ccf9d9d..fea863e 100644 --- a/src/a4d/clean/converters.py +++ b/src/a4d/clean/converters.py @@ -12,8 +12,9 @@ """ import polars as pl +from loguru import logger -from a4d.clean.date_parser import parse_date_flexible +from a4d.clean.date_parser import parse_date_flexible, rescue_date_typos from a4d.config import settings from a4d.errors import ErrorCollector @@ -126,6 +127,59 @@ def safe_convert_column( return df +def _apply_typo_rescue( + df: pl.DataFrame, + column: str, + error_collector: ErrorCollector, + file_name_col: str, + patient_id_col: str, +) -> pl.DataFrame: + """Rewrite known month-name typos in-place before parsing. + + Builds a rescue_map from unique strings, logs each affected row to + error_collector + loguru with code "typo_rescued", then applies the + substitutions column-wide. No-op if no typos match. 
+ """ + rescue_map: dict[str, str] = {} + for s in df[column].drop_nulls().unique().to_list(): + rescued, was_rescued = rescue_date_typos(s) + if was_rescued: + rescue_map[s] = rescued + + if not rescue_map: + return df + + select_cols = [c for c in (file_name_col, patient_id_col) if c in df.columns] + for original, rescued_val in rescue_map.items(): + if select_cols: + affected = df.filter(pl.col(column) == original).select(select_cols) + for row in affected.iter_rows(named=True): + file_name = row.get(file_name_col) or "unknown" + patient_id = row.get(patient_id_col) or "unknown" + logger.bind(error_code="typo_rescued").warning( + f"date typo rescued in {column}: {original!r} -> {rescued_val!r} " + f"(file={file_name!r}, {patient_id_col}={patient_id!r})" + ) + error_collector.add_error( + file_name=str(file_name), + patient_id=str(patient_id), + column=column, + original_value=original, + error_message=f"date typo rescued: '{original}' -> '{rescued_val}'", + error_code="typo_rescued", + function_name="parse_date_column", + ) + + repl_expr = pl.col(column) + for original, rescued_val in rescue_map.items(): + repl_expr = ( + pl.when(pl.col(column) == original) + .then(pl.lit(rescued_val)) + .otherwise(repl_expr) + ) + return df.with_columns(repl_expr.alias(column)) + + def parse_date_column( df: pl.DataFrame, column: str, @@ -161,20 +215,35 @@ def parse_date_column( if column not in df.columns: return df + # Substitute known month-name typos (e.g. "MACH" -> "MAR") before parsing, + # logging each affected row so the source tracker remains visible to + # data-quality triage. Skipped silently when no typos match. + df = _apply_typo_rescue(df, column, error_collector, file_name_col, patient_id_col) + # Store original values for error reporting df = df.with_columns(pl.col(column).alias(f"_orig_{column}")) - # Apply parse_date_flexible to each value - # NOTE: Using list-based approach instead of map_elements() because - # map_elements() with return_dtype=pl.Date fails when ALL values are None - # (all-NA columns like hospitalisation_date). - # Explicit Series creation with dtype=pl.Date works because it doesn't - # require non-null values. - column_values = df[column].cast(pl.Utf8).to_list() - parsed_dates = [ - parse_date_flexible(val, error_val=settings.error_val_date) for val in column_values - ] - parsed_series = pl.Series(f"_parsed_{column}", parsed_dates, dtype=pl.Date) + # Parse each distinct string once, then map back. Tracker data has heavy + # duplication in date columns (e.g. "1/1/2024" repeating per row), so + # dedup-then-map is much faster than a per-row Python call. + # All-null columns short-circuit: map_elements can't infer Date dtype on + # an empty-after-drop_nulls Series. + col_str = df[column].cast(pl.Utf8) + unique_strs = col_str.drop_nulls().unique().to_list() + if unique_strs: + lookup = { + s: parse_date_flexible(s, error_val=settings.error_val_date) + for s in unique_strs + } + # Polars 1.34 ignores return_dtype=pl.Date when every mapped output is + # None and falls back to the input series' dtype (Utf8). Cast explicitly + # so the downstream `_parsed == error_date` comparison stays Date-vs-Date. 
+ parsed_series = col_str.map_elements( + lambda v: lookup.get(v) if v is not None else None, + return_dtype=pl.Date, + ).cast(pl.Date).alias(f"_parsed_{column}") + else: + parsed_series = pl.Series(f"_parsed_{column}", [None] * df.height, dtype=pl.Date) df = df.with_columns(parsed_series) # Detect failures: parsed to error date diff --git a/src/a4d/clean/date_parser.py b/src/a4d/clean/date_parser.py index e33e446..d53f7bf 100644 --- a/src/a4d/clean/date_parser.py +++ b/src/a4d/clean/date_parser.py @@ -18,6 +18,28 @@ # Excel epoch: dates stored as days since this date EXCEL_EPOCH = date(1899, 12, 30) +# Known data-entry typos in month names. Patterns are applied case-insensitively +# with word boundaries so unrelated text containing these substrings (e.g. +# "CON0CT") is not rewritten. Replacements use uppercase since downstream +# parsers are case-insensitive. +TYPO_REPLACEMENTS: list[tuple[str, str]] = [ + (r"(?i)\bMACH\b", "MAR"), + (r"(?i)\b0CT\b", "OCT"), + (r"(?i)\b0CTOBER\b", "OCTOBER"), + (r"(?i)\bN0V\b", "NOV"), + (r"(?i)\bN0VEMBER\b", "NOVEMBER"), +] + + +def rescue_date_typos(s: str) -> tuple[str, bool]: + """Substitute known month-name typos. Returns (possibly-rewritten, was_rescued).""" + rescued = False + for pattern, replacement in TYPO_REPLACEMENTS: + new_s, n = re.subn(pattern, replacement, s) + if n > 0: + s, rescued = new_s, True + return s, rescued + def parse_date_flexible(date_str: str | None, error_val: str = "9999-09-09") -> date | None: """Parse date strings flexibly using Python's dateutil.parser. diff --git a/src/a4d/clean/patient.py b/src/a4d/clean/patient.py index a47e7b9..05b8623 100644 --- a/src/a4d/clean/patient.py +++ b/src/a4d/clean/patient.py @@ -318,10 +318,11 @@ def _derive_insulin_fields(df: pl.DataFrame) -> pl.DataFrame: For 2024+ trackers: - insulin_type: "human insulin" if any human column is Y, else "analog insulin" - insulin_subtype: Comma-separated list like "pre-mixed,rapid-acting,long-acting" - (will be replaced with "Undefined" by validation since - comma-separated values aren't in allowed_values) + Validation uses allow_csv_subset (see reference_data/validation_rules.yaml) to + accept each token against allowed_values and rejoin in canonical case. - NOTE: Python is CORRECT here. Comparison with R will show differences because R has a typo. + NOTE: Python is CORRECT here. Comparison with R will show differences because R has a typo + and because R's validator rejects its own multi-insulin CSV output. 
Args: df: Input DataFrame with individual insulin columns @@ -824,12 +825,12 @@ def _validate_dates(df: pl.DataFrame, error_collector: ErrorCollector) -> pl.Dat pl.col(col).is_not_null() & (pl.col(col) > pl.col("_max_valid_date")) ) - # Log each error - for row in invalid_dates.iter_rows(named=True): - patient_id = row.get("patient_id", "UNKNOWN") - file_name = row.get("file_name", "UNKNOWN") - original_date = row.get(col) - tracker_year = row.get("tracker_year") + # Log each error (tuple-unpack avoids per-row dict construction) + for patient_id, file_name, original_date, tracker_year in invalid_dates.select( + "patient_id", "file_name", col, "tracker_year" + ).iter_rows(): + patient_id = patient_id if patient_id is not None else "UNKNOWN" + file_name = file_name if file_name is not None else "UNKNOWN" logger.bind(error_code="invalid_value").warning( f"Patient {patient_id}: {col} = {original_date} " diff --git a/src/a4d/clean/product.py b/src/a4d/clean/product.py new file mode 100644 index 0000000..c23a422 --- /dev/null +++ b/src/a4d/clean/product.py @@ -0,0 +1,1009 @@ +"""Product data cleaning pipeline. + +Mirrors ``clean/patient.py`` architecture: a single orchestrator +(``clean_product_data``) dispatches to step-scoped private helpers. Covers +R Script 2 steps 2.1-2.22. +""" + +from pathlib import Path + +import polars as pl + +from a4d.clean.converters import parse_date_column, safe_convert_column +from a4d.clean.schema_product import apply_schema, get_product_data_schema, get_string_columns +from a4d.config import settings +from a4d.errors import ErrorCollector +from a4d.reference.products import load_known_products, load_product_categories + +ACTIVITY_COLS: tuple[str, ...] = ( + "product_entry_date", + "product_units_received", + "product_received_from", + "product_units_released", + "product_released_to", + "product_units_returned", + "product_returned_by", +) + +UNIT_COLS: tuple[str, ...] = ( + "product_units_received", + "product_units_released", + "product_units_returned", +) + +EMPTY_ROW_COLS: tuple[str, ...] = ( + "product_units_received", + "product_units_released", + "product_units_returned", + "product_released_to", + "product_entry_date", + "product_balance", +) + +# Buddhist Era / Common Era disambiguation point. Mandalay trackers carry +# Buddhist-era dates (BE = CE + 543); a parsed year >= 2400 is implausibly +# Gregorian and treated as either Buddhist-era or already-sentinel +# (error_val_date 9999-09-09). _validate_entry_dates skips these so they +# flow through unchanged, matching R, while still flagging fat-finger +# Gregorian futures (e.g. 2099 in a 2024 tracker). +BUDDHIST_ERA_THRESHOLD: int = 2400 + +# Lower-bound year guard for _validate_entry_dates. A parsed Gregorian +# year more than YEAR_FLOOR_DELTA years before the tracker's calendar +# year is implausible (start-balance backfill is months-to-a-few-years, +# not decades) and almost certainly a tiny-int Excel-serial mis-coercion +# (e.g. a raw cell holding `29` parsing to 1900-01-29). +YEAR_FLOOR_DELTA: int = 5 + + +def clean_product_data( + df_raw: pl.DataFrame, + error_collector: ErrorCollector, +) -> pl.DataFrame: + """Clean raw product data through the full R Script 2 sequence. + + Executes steps 2.1-2.22 in order and returns a DataFrame conforming to + the product meta schema defined in ``clean/schema_product.py``. + + Args: + df_raw: Raw product DataFrame from extraction. + error_collector: Accumulator for row-level data quality errors. + + Returns: + Cleaned product DataFrame. 
+ """ + df = _normalize_empty_strings_to_null(df_raw) # 2.0 (see helper docstring) + df = _split_multi_product_cells(df) # 2.1 + df = _switch_misplaced_columns(df) # 2.3 + df = _remove_uninformative_rows(df) # 2.4 + df = _add_row_index(df) # 2.5 + df = _format_dates(df, error_collector) # 2.6 + _check_entry_dates_match_sheet(df, error_collector) # 2.6a (R-parity log) + df = _validate_entry_dates(df, error_collector) # 2.6b + df = _fill_product_names_and_sort(df) # 2.7 + df = _extract_balance_from_received(df) # 2.8 + df = _recode_na_units_to_zero(df) # 2.9 + df = _clean_received_from(df) # 2.10 + df = _clean_units_received(df, error_collector) # 2.11 + df = _recode_na_units_to_zero(df) # 2.12 + df = _remove_empty_data_rows(df) # 2.13 + df = _compute_balance_status(df) # 2.14 + df = _compute_running_balance(df) # 2.15 + + # 2.16 — type cast numeric/date columns via ErrorCollector; strip strings + # so trailing whitespace from openpyxl matches R's readxl trim-on-read. + # No string->Int intermediate needed (cf. patient pipeline's Int32-via-Float64 + # path): product unit columns are Float64 by schema, and the only Int columns + # (product_table_year/month) arrive from extraction as numeric, not strings. + schema = get_product_data_schema() + cast_targets = (pl.Int32, pl.Int64, pl.Float32, pl.Float64, pl.Date) + for col, target in schema.items(): + if col not in df.columns or target not in cast_targets: + continue + if df.schema[col] == target: + continue + df = safe_convert_column(df, col, target, error_collector) + + string_cols = [c for c in get_string_columns() if c in df.columns] + if string_cols: + df = df.with_columns([pl.col(c).str.strip_chars() for c in string_cols]) + + # 2.17 — drop helper index column added in 2.5. + if "index" in df.columns: + df = df.drop("index") + + df = _validate_negative_balances(df, error_collector) # 2.18 + df = _report_unknown_products(df, error_collector) # 2.19 + df = _add_product_categories(df) # 2.20 + df = _extract_unit_capacity(df) # 2.21 + # 2.22 cross-month combine happens at the table stage (S4-T1), not here. + + # Final schema conformance: guarantees 19 columns in schema order. + df = apply_schema(df) + # R-parity: UNIT_COLS treat absence as 0 (helper_product_data.R:292-297), + # so re-run the recode after schema seeding fills any newly-added column. + return _recode_na_units_to_zero(df) + + +def clean_product_file( + raw_parquet_path: Path, + output_parquet_path: Path, + error_collector: ErrorCollector | None = None, +) -> None: + """Clean a single product parquet file (I/O wrapper). + + Reads the raw parquet, runs ``clean_product_data``, and writes the result + to the output path. Mirrors ``clean_patient_file``. + + Args: + raw_parquet_path: Raw product parquet produced by extraction. + output_parquet_path: Destination for the cleaned parquet. + error_collector: Optional ErrorCollector; a new one is created if None. + """ + ec = error_collector if error_collector is not None else ErrorCollector() + df_raw = pl.read_parquet(raw_parquet_path) + df_clean = clean_product_data(df_raw, ec) + df_clean.write_parquet(output_parquet_path) + + +def _normalize_empty_strings_to_null(df: pl.DataFrame) -> pl.DataFrame: + """Step 2.0 — coerce "" / whitespace-only string cells to null. + + R's readxl returns NA for blank Excel cells; openpyxl returns an empty + string. 
Without this normalisation, rows whose only content is an empty + string survive `_remove_uninformative_rows` and `_remove_empty_data_rows` + (observed on 2024 CDA Dec24, which has one such row in + product_units_received). + """ + exprs = [ + pl.when(pl.col(c).str.strip_chars() == "") + .then(None) + .otherwise(pl.col(c)) + .alias(c) + for c, dtype in df.schema.items() + if dtype == pl.String + ] + return df.with_columns(exprs) if exprs else df + + +def _split_multi_product_cells(df: pl.DataFrame) -> pl.DataFrame: + """Step 2.1 — split cells containing multiple products. + + Splits the ``product`` column on "; " and " and ". Parenthesized unit + hints (e.g. "Product A (1 box)") are extracted into ``product_units_notes`` + and the numeric quantity is routed to ``product_units_received`` or + ``product_units_released`` based on which side has a counterparty. + """ + if "product" not in df.columns: + return df + + df = df.with_columns( + pl.col("product").cast(pl.Utf8).str.replace_all(" and ", "; ").alias("product") + ) + df = df.with_columns(pl.col("product").str.split("; ")).explode("product") + + if "product_units_notes" not in df.columns: + df = df.with_columns(pl.lit(None, dtype=pl.Utf8).alias("product_units_notes")) + else: + df = df.with_columns(pl.col("product_units_notes").cast(pl.Utf8)) + + for col in ("product_units_received", "product_units_released"): + if col in df.columns: + df = df.with_columns(pl.col(col).cast(pl.Utf8)) + else: + df = df.with_columns(pl.lit(None, dtype=pl.Utf8).alias(col)) + + product = pl.col("product").cast(pl.Utf8) + has_paren_and_kw = ( + product.str.contains(r"\(") + & product.str.contains(r"\)") + & (product.str.contains("box") | product.str.contains("unit")) + ) + no_slash = ~product.str.contains("/") + + paren_text = product.str.extract(r"\(([^()]+)\)", 1) + # Deliberate deviation from R (helper_product_data.R:579,594): R uses + # "[1-9]+" which drops the leading digit of counts containing 0 + # (e.g. "(10 box)" -> "1"). No real tracker row currently triggers + # box/unit extraction, but "\d+" is the correct regex. + number_str = paren_text.str.extract(r"(\d+)", 1) + + df = df.with_columns( + pl.when(has_paren_and_kw) + .then(paren_text) + .otherwise(pl.col("product_units_notes")) + .alias("product_units_notes") + ) + + if "product_received_from" in df.columns: + df = df.with_columns( + pl.when( + has_paren_and_kw + & no_slash + & pl.col("product_received_from").is_not_null() + ) + .then(number_str) + .otherwise(pl.col("product_units_received")) + .alias("product_units_received") + ) + + if "product_released_to" in df.columns: + df = df.with_columns( + pl.when( + has_paren_and_kw + & no_slash + & pl.col("product_released_to").is_not_null() + ) + .then(number_str) + .otherwise(pl.col("product_units_released")) + .alias("product_units_released") + ) + + return df + + +def _switch_misplaced_columns(df: pl.DataFrame) -> pl.DataFrame: + """Step 2.3 — swap ``product_units_received`` with ``product_received_from`` + for each sheet that contains "Remaining Stock" in units_received. + + Observed in 2018 PNG Nov/Dec trackers. R applies this rename inside its + per-sheet loop (read_product_data.R:577); Python's clean pipeline + operates on the whole-file DataFrame, so we scope the swap with .over() + to avoid corrupting clean sheets in the same file. 
+ """ + if ( + "product_units_received" not in df.columns + or "product_received_from" not in df.columns + or "product_sheet_name" not in df.columns + ): + return df + + has_remaining_in_sheet = ( + pl.col("product_units_received") + .cast(pl.Utf8) + .str.contains("Remaining Stock") + .any() + .over("product_sheet_name") + ) + + if not df.select(has_remaining_in_sheet.any()).item(): + return df + + return df.with_columns( + pl.when(has_remaining_in_sheet) + .then(pl.col("product_received_from")) + .otherwise(pl.col("product_units_received")) + .alias("product_units_received"), + pl.when(has_remaining_in_sheet) + .then(pl.col("product_units_received")) + .otherwise(pl.col("product_received_from")) + .alias("product_received_from"), + ) + + +def _remove_uninformative_rows(df: pl.DataFrame) -> pl.DataFrame: + """Step 2.4 — drop rows where every activity column is null.""" + existing = [c for c in ACTIVITY_COLS if c in df.columns] + if not existing: + return df + all_null = pl.all_horizontal([pl.col(c).is_null() for c in existing]) + return df.filter(~all_null) + + +def _add_row_index(df: pl.DataFrame) -> pl.DataFrame: + """Step 2.5 — add a 1-based ``index`` column matching R's ``seq(1, nrow)``.""" + return df.with_row_index("index", offset=1) + + +def _format_dates( + df: pl.DataFrame, error_collector: ErrorCollector +) -> pl.DataFrame: + """Step 2.6 — parse ``product_entry_date`` with the flexible date parser. + + Two preprocessing steps run before delegating to ``parse_date_column``: + + 1. Strip a trailing time component (``" HH:MM[:SS]"``) from Excel + datetimes cast to string. Uses a precise regex rather than splitting + on the first space, so date strings that legitimately contain + spaces (e.g. ``"24 Feb 2020"``) survive intact. + 2. Normalize separator typos observed in the corpus (2026-04-29): + ``"--"`` → ``"-"``, period-then-letter → space, underscore → space. + Each rule is unambiguous and does not collide with valid date + formats. Product-scoped only; ``parse_date_flexible`` is shared + with the patient pipeline and is not modified. + """ + if "product_entry_date" not in df.columns: + return df + if df.schema["product_entry_date"] == pl.Date: + return df + df = df.with_columns( + pl.col("product_entry_date") + .cast(pl.Utf8) + .str.replace(r"\s+\d{1,2}:\d{2}(:\d{2})?$", "") + .str.replace_all(r"-{2,}", "-") + .str.replace_all(r"(\d)\.([A-Za-z])", r"$1 $2") + .str.replace_all(r"_", " ") + .alias("product_entry_date") + ) + return parse_date_column( + df, "product_entry_date", error_collector, patient_id_col="product" + ) + + +def _check_entry_dates_match_sheet( + df: pl.DataFrame, error_collector: ErrorCollector +) -> None: + """R-parity warning for entry dates that disagree with the sheet header. + + Mirrors R's ``check_entry_dates`` (read_product_data.R): one log entry + per row where the parsed ``product_entry_date`` doesn't match + ``(product_table_year, product_table_month)``. Sentinels, nulls, and + Buddhist-era dates are skipped. + + R filters to numeric Excel-serial cells before checking; Python checks + every parsed date because ``parse_date_flexible`` accepts both serials + and text. This is a deliberate parity-or-better expansion — the integration + diff harness tolerates the row-count drift. + + Side-effecting only: pushes log entries; never mutates ``df``. 
+ """ + needed = {"product_entry_date", "product_table_year", "product_table_month"} + if not needed.issubset(df.columns): + return + + error_date = pl.lit(settings.error_val_date).str.to_date() + table_year = pl.col("product_table_year").cast(pl.Int32, strict=False) + table_month = pl.col("product_table_month").cast(pl.Int32, strict=False) + + is_real_date = ( + pl.col("product_entry_date").is_not_null() + & (pl.col("product_entry_date") != error_date) + & (pl.col("product_entry_date").dt.year() < BUDDHIST_ERA_THRESHOLD) + ) + mismatch = (pl.col("product_entry_date").dt.month() != table_month) | ( + pl.col("product_entry_date").dt.year() != table_year + ) + + select_cols = [ + "file_name" if "file_name" in df.columns else pl.lit("unknown").alias("file_name"), + "product" if "product" in df.columns else pl.lit("unknown").alias("product"), + "product_entry_date", + table_year.alias("_table_year"), + table_month.alias("_table_month"), + "product_sheet_name" if "product_sheet_name" in df.columns + else pl.lit("unknown").alias("product_sheet_name"), + ] + offenders = df.filter(is_real_date & mismatch).select(select_cols) + + for file_name, product, entry_date, ty, tm, sheet_name in offenders.iter_rows(): + error_collector.add_error( + file_name=file_name or "unknown", + patient_id=product or "unknown", + column="product_entry_date", + original_value=str(entry_date), + error_message=( + f"product_entry_date {entry_date} does not match sheet " + f"'{sheet_name or 'unknown'}' (expected {ty}-{tm:02d})" + ), + error_code="invalid_value", + function_name="check_entry_dates", + ) + + +def _validate_entry_dates( + df: pl.DataFrame, error_collector: ErrorCollector +) -> pl.DataFrame: + """Step 2.6b — flag fat-fingered Gregorian entry dates outside the tracker window. + + A row is flagged when its parsed Gregorian year falls outside + ``[product_table_year - YEAR_FLOOR_DELTA, product_table_year]`` and is below + ``BUDDHIST_ERA_THRESHOLD``. Two divergence patterns are caught, handled + asymmetrically: + + * **Above-max** (future-year typos, e.g. ``2099-03-15`` in a 2024 tracker): + logged AND replaced with ``error_val_date`` (9999-09-09). Future dates are + genuinely ambiguous (premature next-month entry vs typo); the sentinel + preserves that uncertainty signal in the output. + * **Below-min** (year-floor, e.g. ``1967-02-05`` in a 2024 tracker, or raw + cell ``29`` → 1900-01-29 from Excel-serial mis-coercion): logged ONLY; + the parsed date is preserved. Year-floor cells are unambiguously bad data + (1900-2014 in 2020+ trackers). R does not validate — leaving the parsed + date in place aligns the downstream sort/cumsum trajectory with R while + the audit log retains the data-quality flag. + + Above/below cases are logged with distinct messages so triage in + ``table_error_messages.parquet`` can distinguish them. + + Years at or beyond ``BUDDHIST_ERA_THRESHOLD`` (2400) are left untouched on both + branches so Buddhist-era dates (e.g. ``2567-11-11`` from Mandalay trackers) + flow through, and the parse-failure sentinel (9999-09-09) is not re-clobbered + or double-logged. This deliberately diverges from the patient pipeline's + ``_validate_dates``, which still clobbers any future date — patient is + out of scope for this change. 
+ """ + if "product_entry_date" not in df.columns or "product_table_year" not in df.columns: + return df + + error_date = pl.lit(settings.error_val_date).str.to_date() + table_year = pl.col("product_table_year").cast(pl.Int32) + max_valid = pl.date(table_year, 12, 31) + min_valid = pl.date(table_year - YEAR_FLOOR_DELTA, 1, 1) + + not_buddhist = pl.col("product_entry_date").dt.year() < BUDDHIST_ERA_THRESHOLD + above_max_mask = ( + pl.col("product_entry_date").is_not_null() + & (pl.col("product_entry_date") > max_valid) + & not_buddhist + ) + below_min_mask = ( + pl.col("product_entry_date").is_not_null() + & (pl.col("product_entry_date") < min_valid) + & not_buddhist + ) + + above = df.filter(above_max_mask).select( + "file_name", "product", "product_entry_date", "product_table_year", "product_sheet_name" + ) + for file_name, product, entry_date, table_year_val, sheet_name in above.iter_rows(): + error_collector.add_error( + file_name=file_name or "unknown", + patient_id=product or "unknown", + column="product_entry_date", + original_value=str(entry_date), + error_message=( + f"product_entry_date {entry_date} beyond " + f"product_table_year {table_year_val} " + f"(sheet '{sheet_name or 'unknown'}')" + ), + error_code="invalid_value", + function_name="_validate_entry_dates", + ) + + below = df.filter(below_min_mask).select( + "file_name", "product", "product_entry_date", "product_table_year", "product_sheet_name" + ) + for file_name, product, entry_date, table_year_val, sheet_name in below.iter_rows(): + error_collector.add_error( + file_name=file_name or "unknown", + patient_id=product or "unknown", + column="product_entry_date", + original_value=str(entry_date), + error_message=( + f"product_entry_date {entry_date} before " + f"product_table_year {table_year_val} - {YEAR_FLOOR_DELTA} " + f"(sheet '{sheet_name or 'unknown'}')" + ), + error_code="invalid_value", + function_name="_validate_entry_dates", + ) + + return df.with_columns( + pl.when(above_max_mask) + .then(error_date) + .otherwise(pl.col("product_entry_date")) + .alias("product_entry_date") + ) + + +def _fill_product_names_and_sort(df: pl.DataFrame) -> pl.DataFrame: + """Step 2.7 — forward-fill ``product`` then sort inside each product group. + + R runs this inside a per-sheet for-loop, so forward_fill and the + first/last tier windows are scoped to (product, product_sheet_name). + Without that scoping a null product at the top of sheet N would pick + up the last product of sheet N-1, and a single product appearing in + multiple sheets would collapse to one start/end pair instead of one + per sheet. + + Within each (sheet, product) group the rank expression mirrors R's + (read_product_data.R:610-615): + rank = 1 if first row in group + = n + 2 if last row in group + = row_number if middle row with null date (preserves input order) + = dense_rank(date) + 1 if middle row with valid date + Stable sort on (product_table_month, product, _rank) reproduces R's + per-sheet for-loop + rbind output: sheet-major, product-minor. + + The parse-failure sentinel (``settings.error_val_date`` = 9999-09-09) + counts as null for rank purposes — R drops unparseable dates to NA, so + the input-order branch must catch sentinels too. Without this, sentinels + sort to the dense_d+1 end-of-changes position and corrupt the + cumulative-balance order vs. R. 
+ """ + if "product" not in df.columns: + return df + if "index" not in df.columns: + raise KeyError( + "_fill_product_names_and_sort requires an 'index' column; run _add_row_index first" + ) + + group = ["product_sheet_name", "product"] if "product_sheet_name" in df.columns else ["product"] + sheet_group = ["product_sheet_name"] if "product_sheet_name" in df.columns else None + + if sheet_group is not None: + df = df.with_columns(pl.col("product").forward_fill().over(sheet_group)) + else: + df = df.with_columns(pl.col("product").forward_fill()) + + row_n = pl.col("index").rank("ordinal").over(group).cast(pl.Int64) + group_n = pl.col("index").count().over(group).cast(pl.Int64) + has_date = "product_entry_date" in df.columns + dense_d = ( + pl.col("product_entry_date").rank("dense").over(group).cast(pl.Int64) + if has_date + else pl.lit(None, dtype=pl.Int64) + ) + date_is_null = ( + ( + pl.col("product_entry_date").is_null() + | (pl.col("product_entry_date") == pl.lit(settings.error_val_date).str.to_date()) + ) + if has_date + else pl.lit(True) + ) + + rank_expr = ( + pl.when(row_n == 1) + .then(pl.lit(1, dtype=pl.Int64)) + .when(row_n == group_n) + .then(group_n + 2) + .when(date_is_null) + .then(row_n) + .otherwise(dense_d + 1) + .alias("_rank") + ) + + sort_cols: list[str] = [] + if "product_table_month" in df.columns: + sort_cols.append("product_table_month") + sort_cols.extend(["product", "_rank"]) + + return ( + df.with_columns(rank_expr) + .sort(sort_cols, nulls_last=True, maintain_order=True) + .drop("_rank") + ) + + +def _extract_balance_from_received(df: pl.DataFrame) -> pl.DataFrame: + """Step 2.8 — move stray balance values out of ``product_units_released``. + + For trackers (e.g. 2019 PKH, 2020 STH) where "Balance" appears in + ``product_units_received``, the value sitting in ``product_units_released`` + is relocated to ``product_received_from`` and released is cleared. + + R applies this rewrite inside its per-sheet loop (helper_product_data.R:329); + Python's clean pipeline operates on the whole-file DataFrame, so we scope + the trigger with .over("product_sheet_name") to avoid blanking + ``product_received_from`` on clean sheets that share a file with a sheet + using the "Balance" convention. + + Diverges from R (helper_product_data.R:329-335): R's case_when has no + default arm, so unmatched rows on a triggered sheet have their + ``product_received_from`` nulled. The trigger regex ``(?i)Balance`` + substring-matches "START BALANCE" / "END BALANCE" — present on every + standard sheet — so R's no-default behaviour silently wipes legitimate + supplier names (e.g. ``DKSH`` on Mahosot 2020 stock-receipt rows). + Python preserves the value via the ``.otherwise`` arm so audit-trail + data survives. Mutation of ``product_units_released`` is correspondingly + scoped to actual Balance rows (not all triggered rows) so that + preserving ``received_from`` on a non-Balance row does not collateral + a non-null ``released`` on the same row. 
+ """ + required = ( + "product_units_received", + "product_units_released", + "product_received_from", + "product_sheet_name", + ) + if not all(c in df.columns for c in required): + return df + + balance_mask = pl.col("product_units_received").cast(pl.Utf8).str.contains("(?i)Balance") + sheet_triggered = ( + balance_mask.any().over("product_sheet_name") + & pl.col("product_received_from").is_null().any().over("product_sheet_name") + ) + + if not df.select(sheet_triggered.any()).item(): + return df + + # Materialize the trigger predicates before mutating product_received_from. + # `sheet_triggered` includes `received_from.is_null().any().over(sheet)`; if + # we let it re-evaluate in the second .with_columns() below, sheets where + # every row is a Balance marker (e.g. 2019 PKH Oct19) flip the trigger to + # False after the first call populates received_from on every row, and the + # released-clear pass becomes a no-op. Caching pins the pre-mutation truth. + df = df.with_columns( + sheet_triggered.alias("_sheet_triggered"), + balance_mask.alias("_balance_mask"), + ) + + # Two positive arms relocate the balance value into received_from on + # actual Balance-marker rows; the catch-all preserves received_from on + # all other rows. R's case_when has no default and would null those + # rows — see docstring. The "Total" arm before the catch-all nulls + # typist subtotal labels (e.g. "Accu-Chek Performa | Total | 35" + # subtotal rows in Penang DC / VNCH / Mandalay 2019 trackers); R nulled + # these implicitly via its no-default, and "Total" is never a real + # supplier — it's the label the typist put on the end-of-product-block + # subtotal row. + df = df.with_columns( + pl.when(pl.col("_sheet_triggered") & pl.col("_balance_mask") & pl.col("product_units_released").is_not_null()) + .then(pl.col("product_units_released").cast(pl.Utf8)) + .when(pl.col("_sheet_triggered") & pl.col("_balance_mask") & pl.col("product_received_from").is_not_null()) + .then(pl.col("product_received_from").cast(pl.Utf8)) + .when(pl.col("product_received_from") == "Total") + .then(pl.lit(None, dtype=pl.Utf8)) + .otherwise(pl.col("product_received_from")) + .alias("product_received_from") + ) + df = df.with_columns( + pl.when(pl.col("_sheet_triggered") & pl.col("_balance_mask") & pl.col("product_received_from").is_not_null()) + .then(pl.lit(None, dtype=pl.Utf8)) + .otherwise(pl.col("product_units_released")) + .alias("product_units_released") + ) + return df.drop("_sheet_triggered", "_balance_mask") + + +def _recode_na_units_to_zero(df: pl.DataFrame) -> pl.DataFrame: + """Steps 2.9 and 2.12 — fill NA in unit columns with 0. + + Applied twice: once before the string cleans in 2.10/2.11 and once after + to catch nulls introduced by those cleans. Respects column dtype — Utf8 + cols get "0", numeric cols get 0. + """ + for col in UNIT_COLS: + if col not in df.columns: + continue + fill: float | str = 0 if df.schema[col].is_numeric() else "0" + df = df.with_columns(pl.col(col).fill_null(fill)) + return df + + +def _clean_received_from(df: pl.DataFrame) -> pl.DataFrame: + """Step 2.10 — normalise the ``product_received_from`` column. + + If ``product_units_received`` contains "START", the corresponding + ``product_received_from`` value is copied to ``product_balance`` as the + start balance. Purely numeric entries (should be supplier names) are + removed. + """ + if "product_received_from" not in df.columns: + return df + + if "product_units_received" in df.columns: + # Preserve existing product_balance on non-START rows. 
Earlier code + # used .otherwise(pl.lit(None)) which only happened to be a no-op + # because nothing populates product_balance before this step; if a + # future step seeds it, that work would be silently wiped. + # .over("product_sheet_name") matches R's per-sheet loop semantics + # (helper_product_data.R:354-377) — defensive even though the current + # row-wise when/then/otherwise has no cross-row dependency. + if "product_balance" not in df.columns: + df = df.with_columns(pl.lit(None, dtype=pl.Float64).alias("product_balance")) + start_mask = pl.col("product_units_received").cast(pl.Utf8).str.contains("(?i)START") + df = df.with_columns( + pl.when(start_mask) + .then(pl.col("product_received_from").cast(pl.Utf8)) + .otherwise(pl.col("product_balance").cast(pl.Utf8)) + .over("product_sheet_name") + .alias("product_balance") + ) + + no_alpha = ~pl.col("product_received_from").cast(pl.Utf8).str.contains(r"[A-Za-z]") + df = df.with_columns( + pl.when(pl.col("product_received_from").is_not_null() & no_alpha) + .then(pl.lit(None)) + .otherwise(pl.col("product_received_from")) + .alias("product_received_from") + ) + return df + + +def _clean_units_received( + df: pl.DataFrame, + error_collector: ErrorCollector, +) -> pl.DataFrame: + """Step 2.11 — zero out balance markers and cast to numeric. + + Rows where ``product_units_received`` contains "START", "END" or + "BALANCE" are set to 0. Remaining values are cast to numeric; failures + yield null (then 0 via the second pass of step 2.12) and emit one + ``type_conversion`` entry per row to ``error_collector`` — R-parity with + ``script3_create_table_product_data.R::preparing_product_fields``'s + ``invalid_value`` warnings. + """ + if "product_units_received" not in df.columns: + return df + + marker_mask = pl.col("product_units_received").cast(pl.Utf8).str.contains( + "(?i)START|END|BALANCE" + ) + casted = pl.col("product_units_received").cast(pl.Float64, strict=False) + + failure_mask = ( + pl.col("product_units_received").is_not_null() & ~marker_mask & casted.is_null() + ) + failures = df.filter(failure_mask) + for row in failures.iter_rows(named=True): + original = row["product_units_received"] + error_collector.add_error( + file_name=row.get("file_name") or "unknown", + patient_id=row.get("product") or "unknown", + column="product_units_received", + error_code="type_conversion", + function_name="_clean_units_received", + original_value=str(original), + error_message=( + f"product_units_received '{original}' could not be converted to numeric" + ), + ) + + return df.with_columns( + pl.when(marker_mask).then(pl.lit(0.0)).otherwise(casted).alias("product_units_received") + ) + + +def _remove_empty_data_rows(df: pl.DataFrame) -> pl.DataFrame: + """Step 2.13 — drop rows with no meaningful transaction data. + + Removes rows where units_received, units_released, units_returned, + released_to, entry_date and balance are all null. + """ + existing = [c for c in EMPTY_ROW_COLS if c in df.columns] + if not existing: + return df + all_null = pl.all_horizontal([pl.col(c).is_null() for c in existing]) + return df.filter(~all_null) + + +def _compute_balance_status(df: pl.DataFrame) -> pl.DataFrame: + """Step 2.14 — label each row ``start`` / ``change`` / ``end`` per (product, sheet). + + First row in each (product, product_sheet_name) group is ``start``, last + is ``end``, all others are ``change``. R processes this per-sheet, so a + product that appears in 12 sheets has 12 start rows and 12 end rows — + not a single start/end across the whole tracker. 
+ """ + if "index" not in df.columns: + raise KeyError( + "_compute_balance_status requires an 'index' column; run _add_row_index first" + ) + + group = ["product_sheet_name", "product"] if "product_sheet_name" in df.columns else ["product"] + + return df.with_columns( + pl.when(pl.col("index") == pl.col("index").first().over(group)) + .then(pl.lit("start")) + .when(pl.col("index") == pl.col("index").last().over(group)) + .then(pl.lit("end")) + .otherwise(pl.lit("change")) + .alias("product_balance_status") + ) + + +def _compute_running_balance(df: pl.DataFrame) -> pl.DataFrame: + """Step 2.15 — compute the running stock balance per product. + + Formula: ``balance[i] = balance[i-1] - released[i] + received[i]``. + For year >= 2021 the ``end`` rows' received/released are zeroed first + (they are summary rows, not real transactions). The year is read from + the ``product_table_year`` column (set by extraction). + + Implementation uses a vectorized cumsum per product instead of R's + iterative loop. Equivalent when step 2.12 has filled unit-column nulls + with 0 and step 2.7's sort is stable — corrupt upstream state could + diverge (R's loop propagates NA forward; cumsum treats it as 0 after + 2.12's fill). + + Preconditions: ``_fill_product_names_and_sort`` (step 2.7) and + ``_compute_balance_status`` (step 2.14) must have run. The first + positional row in each (sheet, product) group must be labeled + ``"start"`` because the cumsum below seeds from + ``product_balance.first().over(group)``. + """ + group = ( + ["product_sheet_name", "product"] + if "product_sheet_name" in df.columns + else ["product"] + ) + + if "product_balance_status" not in df.columns: + raise RuntimeError( + "_compute_running_balance requires product_balance_status; " + "run _compute_balance_status (step 2.14) first" + ) + if df.height > 0: + first_status = df.select( + pl.col("product_balance_status").first().over(group).alias("_fs") + )["_fs"] + if not (first_status == "start").all(): + bad_groups = ( + df.with_columns( + pl.col("product_balance_status").first().over(group).alias("_fs") + ) + .filter(pl.col("_fs") != "start") + .select(group) + .unique() + .height + ) + raise RuntimeError( + f"_compute_running_balance precondition violated: {bad_groups} " + "(sheet, product) groups have a non-'start' row first. " + "Did _fill_product_names_and_sort (step 2.7) run?" 
+ ) + + casts = [ + pl.col("product_balance").cast(pl.Float64, strict=False), + pl.col("product_units_released").cast(pl.Float64, strict=False), + ] + if "product_units_returned" in df.columns: + casts.append(pl.col("product_units_returned").cast(pl.Float64, strict=False)) + df = df.with_columns(casts) + + if "product_table_year" in df.columns: + year_gate = ( + pl.col("product_table_year").is_not_null() + & (pl.col("product_table_year") >= 2021) + & (pl.col("product_balance_status") == "end") + ) + df = df.with_columns( + pl.when(year_gate) + .then(pl.lit(0.0)) + .otherwise(pl.col("product_units_received")) + .alias("product_units_received"), + pl.when(year_gate) + .then(pl.lit(0.0)) + .otherwise(pl.col("product_units_released")) + .alias("product_units_released"), + ) + + df = df.with_columns( + pl.when( + (pl.col("product_balance_status") == "start") + & pl.col("product_balance").is_null() + ) + .then(pl.col("product_units_received") - pl.col("product_units_released")) + .otherwise(pl.col("product_balance")) + .alias("product_balance") + ) + + is_start = pl.col("product_balance_status") == "start" + delta = ( + pl.when(is_start) + .then(pl.lit(0.0)) + .otherwise(pl.col("product_units_received") - pl.col("product_units_released")) + ) + df = df.with_columns( + ( + pl.col("product_balance").first().over(group) + + delta.cum_sum().over(group) + ) + .round(10) + .alias("product_balance") + ) + return df + + +def _validate_negative_balances( + df: pl.DataFrame, + error_collector: ErrorCollector, +) -> pl.DataFrame: + """Step 2.18 — log rows with a negative ``product_balance``. + + Data is not modified; each negative-balance row is reported via the + error collector for downstream investigation. + """ + if "product_balance" not in df.columns: + return df + + negatives = df.filter( + pl.col("product_balance").is_not_null() & (pl.col("product_balance") < 0) + ).select("file_name", "product_balance", "product", "product_sheet_name") + for file_name, balance, product, sheet_name in negatives.iter_rows(): + error_collector.add_error( + file_name=file_name or "unknown", + patient_id="unknown", + column="product_balance", + original_value=balance, + error_message=( + f"Negative balance {balance} for product " + f"'{product or 'unknown'}' in sheet " + f"'{sheet_name or 'unknown'}'" + ), + error_code="invalid_value", + function_name="_validate_negative_balances", + ) + return df + + +def _report_unknown_products( + df: pl.DataFrame, + error_collector: ErrorCollector, +) -> pl.DataFrame: + """Step 2.19 — flag product names missing from the Stock_Summary reference. + + Case-insensitive comparison against the known product list. DataFrame is + returned unchanged; violations are logged to the error collector. + """ + if "product" not in df.columns: + return df + + known = set(load_known_products()) + + # R logs unknowns per-sheet; replicate by keying errors on + # (file_name, product_sheet_name, product) triples. 
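+    # NOTE: the case-insensitivity promised above assumes load_known_products()
+    # already returns lowercased names; only the tracker-side value is
+    # lowercased below.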
+ cols = [ + c for c in ("file_name", "product_sheet_name", "product") if c in df.columns + ] + unknowns = ( + df.filter(pl.col("product").is_not_null()) + .with_columns(pl.col("product").str.to_lowercase().alias("_lower")) + .filter(~pl.col("_lower").is_in(list(known))) + .select(cols) + .unique() + ) + for row in unknowns.iter_rows(named=True): + error_collector.add_error( + file_name=row.get("file_name") or "unknown", + patient_id="unknown", + column="product", + original_value=row["product"], + error_message=( + f"Unknown product '{row['product']}' in sheet " + f"'{row.get('product_sheet_name') or 'unknown'}'" + ), + error_code="invalid_value", + function_name="_report_unknown_products", + ) + return df + + +def _add_product_categories(df: pl.DataFrame) -> pl.DataFrame: + """Step 2.20 — left-join ``product_category`` from Stock_Summary. + + Products not present in the reference receive a null category. + """ + if "product" not in df.columns: + return df + + # columns: product (lowercased), product_category + categories = load_product_categories() + + df = df.with_columns( + pl.col("product").cast(pl.Utf8).str.to_lowercase().alias("_product_join") + ) + df = df.join( + categories.rename({"product": "_product_join"}), + on="_product_join", + how="left", + ) + return df.drop("_product_join") + + +def _extract_unit_capacity(df: pl.DataFrame) -> pl.DataFrame: + """Step 2.21 — parse units-per-package from the product name. + + Reads parenthesised hints such as ``(2s)`` or ``(10's)`` into + ``product_unit_capacity``. ``singles`` maps to 1; absent hints default + to 1 as well. + """ + if "product" not in df.columns: + return df.with_columns( + pl.lit(1, dtype=pl.Int32).alias("product_unit_capacity") + ) + + paren = pl.col("product").cast(pl.Utf8).str.extract(r"\(([^()]+)\)", 1) + paren_normalized = ( + pl.when(paren.str.contains("(?i)singles")) + .then(pl.lit("1s")) + .otherwise(paren) + ) + # Digits immediately followed by an optional apostrophe and then "s" + # (covers "(10s)", "(5's)"). Non-matching paren content → null → 1. + digits = paren_normalized.str.extract(r"(\d+)'?s", 1) + capacity = digits.cast(pl.Int32, strict=False).fill_null(1) + + return df.with_columns(capacity.alias("product_unit_capacity")) diff --git a/src/a4d/clean/schema.py b/src/a4d/clean/schema.py index 3748ce1..64b09ff 100644 --- a/src/a4d/clean/schema.py +++ b/src/a4d/clean/schema.py @@ -103,10 +103,10 @@ def apply_schema(df: pl.DataFrame) -> pl.DataFrame: """Apply the meta schema to a DataFrame. This function: - 1. Adds missing columns with NULL values - 2. Casts existing columns to target types (if they exist) - 3. Reorders columns to match schema order - 4. Returns a DataFrame with the exact schema + 1. Adds missing columns with NULL values typed per the schema. + 2. Reorders columns to match schema order. + + Casting is the caller's responsibility (see ``safe_convert_column``). Args: df: Input DataFrame (may be missing columns) diff --git a/src/a4d/clean/schema_product.py b/src/a4d/clean/schema_product.py new file mode 100644 index 0000000..360a42b --- /dev/null +++ b/src/a4d/clean/schema_product.py @@ -0,0 +1,73 @@ +"""Meta schema definition for product data - matches R pipeline's preparing_product_fields().""" + +import polars as pl + + +def get_product_data_schema() -> dict[str, type[pl.DataType] | pl.DataType]: + """Get the complete meta schema for product data. + + This schema matches the R pipeline's preparing_product_fields() in + script3_create_table_product_data.R. Column order matches R's field list. 
+ + Returns: + Dictionary mapping column names to Polars data types + """ + return { + "product": pl.String, # character() in R + "product_units_notes": pl.String, + "product_entry_date": pl.Date, # Date in R + "product_units_released": pl.Float64, # numeric() in R + "product_released_to": pl.String, + "product_units_received": pl.Float64, + "product_received_from": pl.String, + "product_balance": pl.Float64, + "product_units_returned": pl.Float64, + "product_returned_by": pl.String, + "product_table_month": pl.Int32, # integer() in R + "product_table_year": pl.Int32, + "product_sheet_name": pl.String, + "file_name": pl.String, + "product_balance_status": pl.String, + "product_category": pl.String, + "orig_product_released_to": pl.String, + "product_unit_capacity": pl.Int32, + "product_remarks": pl.String, + "clinic_id": pl.String, + } + + +def apply_schema(df: pl.DataFrame) -> pl.DataFrame: + """Apply the meta schema to a DataFrame. + + This function: + 1. Adds missing columns with NULL values typed per the schema. + 2. Reorders columns to match schema order. + + Casting is the caller's responsibility (see ``safe_convert_column``). + + Args: + df: Input DataFrame (may be missing columns) + + Returns: + DataFrame with complete schema applied + """ + schema = get_product_data_schema() + + # Start with existing columns + df_result = df + + # Add missing columns with NULL values + missing_cols = set(schema.keys()) - set(df.columns) + for col in missing_cols: + df_result = df_result.with_columns(pl.lit(None, dtype=schema[col]).alias(col)) + + # Reorder columns to match schema order + df_result = df_result.select(list(schema.keys())) + + return df_result + + +def get_string_columns() -> list[str]: + """Get list of string columns from schema.""" + schema = get_product_data_schema() + return [col for col, dtype in schema.items() if dtype == pl.String] diff --git a/src/a4d/clean/validators.py b/src/a4d/clean/validators.py index f279d52..9e8c4d1 100644 --- a/src/a4d/clean/validators.py +++ b/src/a4d/clean/validators.py @@ -66,6 +66,16 @@ def load_validation_rules() -> dict[str, Any]: return load_yaml(yaml_path) +def load_numeric_ranges() -> dict[str, dict[str, float]]: + """Load the ``numeric_ranges`` block from validation_rules.yaml. + + Consumed by the source-vs-output validator. Mirrors the hardcoded + thresholds in clean/patient.py; see the YAML comment for the drift caveat. + """ + rules = load_validation_rules() + return rules.get("numeric_ranges", {}) + + def validate_allowed_values( df: pl.DataFrame, column: str, @@ -74,6 +84,7 @@ def validate_allowed_values( replace_invalid: bool = True, file_name_col: str = "file_name", patient_id_col: str = "patient_id", + allow_csv_subset: bool = False, ) -> pl.DataFrame: """Validate column against allowed values with case-insensitive matching. @@ -131,6 +142,35 @@ def validate_allowed_values( if sanitized in canonical_mapping: # Valid - replace with canonical value value_replacements[original_val] = canonical_mapping[sanitized] + elif allow_csv_subset and "," in original_val: + # e.g. insulin_subtype "pre-mixed,rapid-acting" is valid if every + # token is in allowed_values. Emit canonical-case CSV. 
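+                    # Tokens are stripped and rejoined with a bare comma, so any
+                    # spacing around commas in the source value is normalised away.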
+ parts = [p.strip() for p in original_val.split(",") if p.strip()] + canonical_parts = [] + all_matched = bool(parts) + for part in parts: + part_sanitized = sanitize_str(part) + if part_sanitized in canonical_mapping: + canonical_parts.append(canonical_mapping[part_sanitized]) + else: + all_matched = False + break + if all_matched: + value_replacements[original_val] = ",".join(canonical_parts) + continue + # Fall through to the invalid branch below. + error_collector.add_error( + file_name="unknown", + patient_id="unknown", + column=column, + original_value=original_val, + error_message=f"Value '{original_val}' not in allowed values (CSV-subset check): {allowed_values}", + error_code="invalid_value", + function_name="validate_allowed_values", + ) + value_replacements[original_val] = ( + settings.error_val_character if replace_invalid else original_val + ) else: # Invalid - log error error_collector.add_error( @@ -198,6 +238,7 @@ def validate_column_from_rules( # Extract validation parameters from simplified rules allowed_values = rules.get("allowed_values", []) replace_invalid = rules.get("replace_invalid", True) + allow_csv_subset = rules.get("allow_csv_subset", False) df = validate_allowed_values( df=df, @@ -207,6 +248,7 @@ def validate_column_from_rules( replace_invalid=replace_invalid, file_name_col=file_name_col, patient_id_col=patient_id_col, + allow_csv_subset=allow_csv_subset, ) return df diff --git a/src/a4d/cli.py b/src/a4d/cli.py index c4b0a96..4786260 100644 --- a/src/a4d/cli.py +++ b/src/a4d/cli.py @@ -15,6 +15,8 @@ process_patient_tables, run_patient_pipeline, ) +from a4d.pipeline.product import process_product_tables, run_product_pipeline +from a4d.state import filter_unchanged_trackers, load_previous_manifest from a4d.tables.logs import create_table_logs # google-crc32c has no pre-built C wheel for Python 3.14 yet; the pure-Python @@ -46,8 +48,8 @@ def _display_tables_summary(tables: dict[str, Path]) -> None: tables_table.add_column("Path", style="green") tables_table.add_column("Records", justify="right", style="magenta") - # Add patient tables first, then logs table - for name in ["static", "monthly", "annual"]: + # Add patient tables first, then product, then logs table + for name in ["static", "monthly", "annual", "product_data"]: if name in tables: path = tables[name] try: @@ -71,6 +73,147 @@ def _display_tables_summary(tables: dict[str, Path]) -> None: console.print() +def _render_pipeline_header( + data_root: str, + output_root: str | Path, + workers: int, + *, + skip_tables: bool = False, + extras: list[tuple[str, str]] | None = None, +) -> None: + """Render the per-command header banner. + + `extras` carries the run-pipeline-only fields (Project / Dataset / Drive / + Download / Upload / Product) so process-patient and process-product can + omit them. All labels are padded to a 13-character column to match the + pre-refactor output exactly. 
+ """ + rows: list[tuple[str, str]] = [ + ("Data root", str(data_root)), + ("Output root", str(output_root)), + ("Workers", str(workers)), + ] + if skip_tables: + rows.append(("Tables", "skipped")) + if extras: + rows.extend(extras) + for label, value in rows: + console.print(f"{label + ':':<13}{value}") + console.print() + + +def _render_pipeline_results_summary( + result, + tables: dict[str, Path], + total_errors: int, + files_with_errors: int, +) -> None: + """Render the 7-row Summary table used by process-patient / process-product.""" + summary_table = Table(title="Summary") + summary_table.add_column("Metric", style="cyan") + summary_table.add_column("Value", style="green") + + summary_table.add_row("Total Trackers", str(result.total_trackers)) + summary_table.add_row("Successful", str(result.successful_trackers)) + summary_table.add_row("Failed", str(result.failed_trackers)) + summary_table.add_row("Tables Created", str(len(tables))) + summary_table.add_row("", "") + summary_table.add_row("Data Quality Errors", f"{total_errors:,}") + summary_table.add_row("Files with Errors", str(files_with_errors)) + + console.print(summary_table) + + +def _resolve_tracker_files( + file: Path | None, + data_root_arg: Path | None, + incremental: bool, + output_root: Path, +) -> tuple[list[Path] | None, str]: + """Resolve the tracker-file list for a CLI invocation. + + Returns ``(tracker_files, display_str)``. ``tracker_files`` is ``None`` when + the orchestrator should discover trackers itself (the default + non-incremental "process everything in data_root" path). An empty list means + discovery + incremental filtering produced no work; the caller should + short-circuit. + + --file always wins; --incremental + --file is a no-op (logged warning), + matching the design that single-file is an explicit user override. + """ + if file: + if incremental: + console.print( + "[yellow]Warning: --incremental is ignored when --file is set[/yellow]" + ) + return [file], f"{file} (single file)" + + from a4d.config import settings as _settings + + if data_root_arg is not None: + files = discover_tracker_files(data_root_arg) + if not files: + console.print( + f"[bold red]Error: No tracker files found in {data_root_arg}[/bold red]\n" + ) + raise typer.Exit(1) + display = str(data_root_arg) + elif incremental: + files = discover_tracker_files(_settings.data_root) + display = str(_settings.data_root) + else: + # Default: orchestrator discovers everything from settings.data_root. + return None, str(_settings.data_root) + + if incremental: + manifest = load_previous_manifest(output_root) + files, summary = filter_unchanged_trackers(files, manifest) + console.print( + f"[cyan]Incremental filter: queued {summary.queued}, " + f"skipped {summary.skipped} unchanged " + f"(new={summary.new}, changed={summary.changed}, " + f"incomplete={summary.previously_incomplete})[/cyan]\n" + ) + + return files, display + + +def _render_failed_trackers( + result, + *, + mode: str, + truncate: int | None = 100, + title: str = "Failed Trackers", + leading_newline: bool = True, +) -> None: + """Render the failed-trackers section. + + `mode="table"` matches process-patient / process-product (Rich Table, + error truncated). `mode="bullets"` matches run-pipeline (bullet list, + full error). `truncate` is ignored in bullets mode. 
+ """ + if result.failed_trackers <= 0: + return + prefix = "\n" if leading_newline else "" + console.print(f"{prefix}[bold yellow]{title}:[/bold yellow]") + if mode == "table": + failed_table = Table() + failed_table.add_column("File", style="red") + failed_table.add_column("Error") + for tr in result.tracker_results: + if not tr.success: + error_text = str(tr.error) + if truncate is not None: + error_text = error_text[:truncate] + failed_table.add_row(tr.tracker_file.name, error_text) + console.print(failed_table) + else: # bullets + for tr in result.tracker_results: + if not tr.success: + console.print(f" • {tr.tracker_file.name}: {tr.error}") + console.print() + + @app.command("process-patient") def process_patient_cmd( file: Annotated[ @@ -90,9 +233,6 @@ def process_patient_cmd( skip_tables: Annotated[ bool, typer.Option("--skip-tables", help="Skip table creation (only extract + clean)") ] = False, - force: Annotated[ - bool, typer.Option("--force", help="Force reprocessing (ignore existing outputs)") - ] = False, data_root: Annotated[ Path | None, typer.Option( @@ -102,12 +242,36 @@ def process_patient_cmd( output_root: Annotated[ Path | None, typer.Option("--output", "-o", help="Output directory (default: from config)") ] = None, + incremental: Annotated[ + bool, + typer.Option( + "--incremental", + help=( + "Skip trackers whose MD5 + completion state match the previous " + "run's manifest. Preserves prior outputs (clean_output disabled)." + ), + ), + ] = False, + force: Annotated[ + bool, + typer.Option( + "--force", + help=( + "Wipe prior outputs and reprocess every tracker. Same as the " + "default behavior; pass explicitly for self-documenting deploy " + "commands. Overrides --incremental if both are passed." + ), + ), + ] = False, ): """Process patient data pipeline. \b - Output is always cleaned before each run so tables reflect only the - current run's files. + By default, output is cleaned before each run so tables reflect only the + current run's files. With --incremental, prior outputs are preserved and + only new/changed/previously-incomplete trackers are re-processed. + With --force, behaves as the default (wipe + reprocess) and overrides + --incremental if both are passed. 
Examples: # Process all trackers in data_root (from config) @@ -124,35 +288,39 @@ def process_patient_cmd( # Just extract + clean, skip tables uv run a4d process-patient --skip-tables + + # Skip trackers whose MD5 matches the previous run's manifest + uv run a4d process-patient --incremental + + # Explicitly wipe outputs and reprocess everything + uv run a4d process-patient --force """ from a4d.config import settings as _settings console.print("\n[bold blue]A4D Patient Pipeline[/bold blue]\n") - if file: - tracker_files = [file] - data_root_display = f"{file} (single file)" - elif data_root: - tracker_files = discover_tracker_files(data_root) - if not tracker_files: - console.print(f"[bold red]Error: No tracker files found in {data_root}[/bold red]\n") - raise typer.Exit(1) - data_root_display = str(data_root) - else: - tracker_files = None # pipeline uses settings.data_root - data_root_display = str(_settings.data_root) + if force and incremental: + console.print( + "[yellow]Warning: --incremental is ignored when --force is set[/yellow]" + ) + incremental = False _output_root = output_root or _settings.output_root _workers = workers if workers is not None else _settings.max_workers - console.print(f"Data root: {data_root_display}") - console.print(f"Output root: {_output_root}") - console.print(f"Workers: {_workers}") - if skip_tables: - console.print("Tables: skipped") - if force: - console.print("Force: yes") - console.print() + tracker_files, data_root_display = _resolve_tracker_files( + file, data_root, incremental, _output_root + ) + + if tracker_files is not None and len(tracker_files) == 0: + console.print( + "[bold green]✓ No trackers need reprocessing — exiting[/bold green]\n" + ) + raise typer.Exit(0) + + _render_pipeline_header( + data_root_display, _output_root, _workers, skip_tables=skip_tables + ) # Step 1: Extract + clean (table creation handled below for visible progress) console.print("[bold]Step 1/3:[/bold] Extracting and cleaning tracker files...") @@ -162,8 +330,7 @@ def process_patient_cmd( max_workers=_workers, output_root=output_root, skip_tables=True, # tables created below with console feedback - force=force, - clean_output=True, + clean_output=force or not incremental, # incremental keeps prior outputs; --force always wipes show_progress=True, console_log_level="ERROR", ) @@ -201,19 +368,7 @@ def process_patient_cmd( total_errors = sum(tr.cleaning_errors for tr in result.tracker_results) files_with_errors = sum(1 for tr in result.tracker_results if tr.cleaning_errors > 0) - summary_table = Table(title="Summary") - summary_table.add_column("Metric", style="cyan") - summary_table.add_column("Value", style="green") - - summary_table.add_row("Total Trackers", str(result.total_trackers)) - summary_table.add_row("Successful", str(result.successful_trackers)) - summary_table.add_row("Failed", str(result.failed_trackers)) - summary_table.add_row("Tables Created", str(len(tables))) - summary_table.add_row("", "") # Spacer - summary_table.add_row("Data Quality Errors", f"{total_errors:,}") - summary_table.add_row("Files with Errors", str(files_with_errors)) - - console.print(summary_table) + _render_pipeline_results_summary(result, tables, total_errors, files_with_errors) # Show error type breakdown if there are errors if total_errors > 0: @@ -241,21 +396,7 @@ def process_patient_cmd( console.print(error_type_table) - # Show failed trackers if any - if result.failed_trackers > 0: - console.print("\n[bold yellow]Failed Trackers:[/bold yellow]") - failed_table = Table() - 
failed_table.add_column("File", style="red") - failed_table.add_column("Error") - - for tr in result.tracker_results: - if not tr.success: - failed_table.add_row( - tr.tracker_file.name, - str(tr.error)[:100], # Truncate long errors - ) - - console.print(failed_table) + _render_failed_trackers(result, mode="table") # Show top files with most data quality errors (if any) if total_errors > 0: @@ -344,7 +485,9 @@ def create_tables_cmd( console.print(f"Found {len(cleaned_files)} cleaned parquet files\n") try: + from a4d.config import settings from a4d.tables.clinic import create_table_clinic_static + from a4d.tables.metadata import create_table_tracker_metadata console.print("[bold]Creating tables...[/bold]") @@ -365,6 +508,21 @@ def create_tables_cmd( clinic_table_path = create_table_clinic_static(output_dir) tables["clinic_data_static"] = clinic_table_path + # Create tracker metadata table (MD5 + per-tracker output presence). + # Skipped if settings.data_root is unreachable — the table needs the + # raw .xlsx files, which create-tables doesn't otherwise require. + if settings.data_root.exists(): + console.print(" • Creating tracker metadata table...") + metadata_path = create_table_tracker_metadata( + settings.data_root, input_dir.parent + ) + tables["tracker_metadata"] = metadata_path + else: + console.print( + f" [yellow]Warning: data_root {settings.data_root} not found, " + "skipping tracker metadata[/yellow]" + ) + # Display results console.print("\n[bold green]✓ Tables created successfully![/bold green]") _display_tables_summary(tables) @@ -374,6 +532,212 @@ def create_tables_cmd( raise typer.Exit(1) from e +@app.command("process-product") +def process_product_cmd( + file: Annotated[ + Path | None, + typer.Option( + "--file", + "-f", + help="Process specific tracker file (if not set, processes all files in data_root)", + ), + ] = None, + workers: Annotated[ + int | None, + typer.Option( + "--workers", "-w", help="Number of parallel workers (default: A4D_MAX_WORKERS)" + ), + ] = None, + skip_tables: Annotated[ + bool, typer.Option("--skip-tables", help="Skip table creation (only extract + clean)") + ] = False, + data_root: Annotated[ + Path | None, + typer.Option( + "--data-root", "-d", help="Directory containing tracker files (default: from config)" + ), + ] = None, + output_root: Annotated[ + Path | None, typer.Option("--output", "-o", help="Output directory (default: from config)") + ] = None, + incremental: Annotated[ + bool, + typer.Option( + "--incremental", + help=( + "Skip trackers whose MD5 + completion state match the previous " + "run's manifest. Preserves prior outputs (clean_output disabled)." + ), + ), + ] = False, + force: Annotated[ + bool, + typer.Option( + "--force", + help=( + "Wipe prior outputs and reprocess every tracker. Same as the " + "default behavior; pass explicitly for self-documenting deploy " + "commands. Overrides --incremental if both are passed." + ), + ), + ] = False, +): + """Process product data pipeline. + + \b + By default, output is cleaned before each run so tables reflect only the + current run's files. With --incremental, prior outputs are preserved and + only new/changed/previously-incomplete trackers are re-processed. + With --force, behaves as the default (wipe + reprocess) and overrides + --incremental if both are passed. 
+ + Examples: + # Process all trackers in data_root (from config) + uv run a4d process-product + + # Process specific file + uv run a4d process-product --file /path/to/tracker.xlsx + + # Parallel processing with 8 workers + uv run a4d process-product --workers 8 + + # Just extract + clean, skip tables + uv run a4d process-product --skip-tables + + # Skip trackers whose MD5 matches the previous run's manifest + uv run a4d process-product --incremental + + # Explicitly wipe outputs and reprocess everything + uv run a4d process-product --force + """ + from a4d.config import settings as _settings + + console.print("\n[bold blue]A4D Product Pipeline[/bold blue]\n") + + if force and incremental: + console.print( + "[yellow]Warning: --incremental is ignored when --force is set[/yellow]" + ) + incremental = False + + _output_root = output_root or _settings.output_root + _workers = workers if workers is not None else _settings.max_workers + + tracker_files, data_root_display = _resolve_tracker_files( + file, data_root, incremental, _output_root + ) + + if tracker_files is not None and len(tracker_files) == 0: + console.print( + "[bold green]✓ No trackers need reprocessing — exiting[/bold green]\n" + ) + raise typer.Exit(0) + + _render_pipeline_header( + data_root_display, _output_root, _workers, skip_tables=skip_tables + ) + + console.print("[bold]Step 1/2:[/bold] Extracting and cleaning product data...") + try: + result = run_product_pipeline( + tracker_files=tracker_files, + max_workers=_workers, + output_root=output_root, + skip_tables=True, + clean_output=force or not incremental, # incremental keeps prior outputs; --force always wipes + show_progress=True, + console_log_level="ERROR", + ) + except Exception as e: + console.print(f"\n[bold red]Error: {e}[/bold red]\n") + raise typer.Exit(1) from e + + tables: dict[str, Path] = {} + if not skip_tables and result.successful_trackers > 0: + cleaned_dir = _output_root / "product_data_cleaned" + tables_dir = _output_root / "tables" + + console.print("[bold]Step 2/2:[/bold] Creating product table...") + try: + tables = process_product_tables(cleaned_dir, tables_dir) + except Exception as e: + console.print(f"[bold red]Error creating tables: {e}[/bold red]") + elif skip_tables: + console.print("[dim]Step 2: Skipped (--skip-tables)[/dim]") + + console.print("\n[bold]Pipeline Results[/bold]\n") + + total_errors = sum(tr.cleaning_errors for tr in result.tracker_results) + files_with_errors = sum(1 for tr in result.tracker_results if tr.cleaning_errors > 0) + + _render_pipeline_results_summary(result, tables, total_errors, files_with_errors) + + _render_failed_trackers(result, mode="table") + + _display_tables_summary(tables) + + if result.success: + console.print("\n[bold green]✓ Product pipeline completed successfully![/bold green]\n") + raise typer.Exit(0) + else: + console.print( + f"\n[bold red]✗ Product pipeline completed with {result.failed_trackers} failures[/bold red]\n" + ) + raise typer.Exit(1) + + +@app.command("create-product-tables") +def create_product_tables_cmd( + input_dir: Annotated[ + Path, + typer.Option("--input", "-i", help="Directory containing cleaned product parquet files"), + ], + output_dir: Annotated[ + Path | None, + typer.Option( + "--output", "-o", help="Output directory for tables (default: input_dir/tables)" + ), + ] = None, +): + """Create the product table from existing cleaned parquet files. 
+ + \b + Examples: + # Create table from existing output + uv run a4d create-product-tables --input output/product_data_cleaned + + # Specify custom output directory + uv run a4d create-product-tables --input output/product_data_cleaned --output custom_tables + """ + console.print("\n[bold blue]A4D Product Table Creation[/bold blue]\n") + + if output_dir is None: + output_dir = input_dir.parent / "tables" + + console.print(f"Input directory: {input_dir}") + console.print(f"Output directory: {output_dir}\n") + + cleaned_files = list(input_dir.glob("*_product_cleaned.parquet")) + if not cleaned_files: + console.print( + f"[bold red]Error: No cleaned product parquet files found in {input_dir}[/bold red]\n" + ) + raise typer.Exit(1) + + console.print(f"Found {len(cleaned_files)} cleaned product parquet files\n") + + try: + console.print("[bold]Creating product table...[/bold]") + tables = process_product_tables(input_dir, output_dir) + + console.print("\n[bold green]✓ Product table created successfully![/bold green]") + _display_tables_summary(tables) + + except Exception as e: + console.print(f"\n[bold red]Error creating product table: {e}[/bold red]\n") + raise typer.Exit(1) from e + + @app.command("upload-tables") def upload_tables_cmd( tables_dir: Annotated[ @@ -563,9 +927,6 @@ def run_pipeline_cmd( "--workers", "-w", help="Number of parallel workers (default: A4D_MAX_WORKERS)" ), ] = None, - force: Annotated[ - bool, typer.Option("--force", help="Force reprocessing (ignore existing outputs)") - ] = False, skip_download: Annotated[ bool, typer.Option("--skip-download", help="Skip GCS download (use files already in data_root)"), @@ -581,6 +942,43 @@ def run_pipeline_cmd( help="Skip Google Drive download of reference data (clinic_data.xlsx)", ), ] = False, + skip_product: Annotated[ + bool, + typer.Option("--skip-product", help="Skip the product pipeline arm."), + ] = False, + skip_patient: Annotated[ + bool, + typer.Option( + "--skip-patient", + help=( + "Skip the patient pipeline arm. Note: leaves " + "tracker_metadata.complete=False for all trackers, so the " + "next --incremental run will re-queue everything." + ), + ), + ] = False, + incremental: Annotated[ + bool, + typer.Option( + "--incremental", + help=( + "Skip trackers whose MD5 + completion state match the previous " + "run's manifest. Both arms see the same filtered queue." + ), + ), + ] = False, + force: Annotated[ + bool, + typer.Option( + "--force", + help=( + "Wipe prior local outputs (raw, cleaned, tables) before each " + "pipeline arm runs. Without this flag, run-pipeline reuses any " + "existing per-tracker parquets on disk. Overrides --incremental " + "if both are passed." + ), + ), + ] = False, ): """Run the full end-to-end A4D pipeline. 
@@ -607,6 +1005,9 @@ def run_pipeline_cmd( # Skip Drive download if clinic_data.xlsx is already current uv run a4d run-pipeline --skip-drive-download + + # Wipe prior outputs before each arm runs + uv run a4d run-pipeline --force """ from a4d.config import settings from a4d.gcp.bigquery import load_pipeline_tables @@ -615,19 +1016,36 @@ def run_pipeline_cmd( from a4d.reference.loaders import find_reference_data_dir from a4d.tables.clinic import create_table_clinic_static + if skip_patient and skip_product: + console.print( + "[bold red]Error: --skip-patient and --skip-product are mutually exclusive[/bold red]\n" + ) + raise typer.Exit(1) + + if force and incremental: + console.print( + "[yellow]Warning: --incremental is ignored when --force is set[/yellow]" + ) + incremental = False + _workers = workers if workers is not None else settings.max_workers run_ts = datetime.now().strftime("%Y/%m/%d/%H%M%S") console.print("\n[bold blue]A4D Full Pipeline[/bold blue]\n") - console.print(f"Data root: {settings.data_root}") - console.print(f"Output root: {settings.output_root}") - console.print(f"Workers: {_workers}") - console.print(f"Project: {settings.project_id}") - console.print(f"Dataset: {settings.dataset}") - console.print(f"Drive: {'yes' if not skip_drive_download else 'skipped (--skip-drive-download)'}") - console.print(f"Download: {'yes' if not skip_download else 'skipped (--skip-download)'}") - console.print(f"Upload: {'yes' if not skip_upload else 'skipped (--skip-upload)'}") - console.print() + extras = [ + ("Project", str(settings.project_id)), + ("Dataset", str(settings.dataset)), + ("Drive", "yes" if not skip_drive_download else "skipped (--skip-drive-download)"), + ("Download", "yes" if not skip_download else "skipped (--skip-download)"), + ("Upload", "yes" if not skip_upload else "skipped (--skip-upload)"), + ("Product", "yes" if not skip_product else "skipped (--skip-product)"), + ("Patient", "yes" if not skip_patient else "skipped (--skip-patient)"), + ("Incremental", "yes" if incremental else "no"), + ("Force", "yes" if force else "no"), + ] + _render_pipeline_header( + settings.data_root, settings.output_root, _workers, extras=extras + ) # Step 0 – Download reference data from Google Drive if not skip_drive_download: @@ -655,35 +1073,66 @@ def run_pipeline_cmd( else: console.print("[bold]Step 1/5:[/bold] Skipping GCS download (--skip-download)\n") - # Step 2+3 – Extract, clean and build tables - console.print("[bold]Steps 2–3/5:[/bold] Processing tracker files...\n") - try: - result = run_patient_pipeline( - max_workers=_workers, - force=force, - show_progress=True, - console_log_level="WARNING", - ) - + # Resolve the tracker queue once. Without --incremental, pass None and let + # each orchestrator discover. With --incremental, discover + filter here so + # both arms see the same queue (single manifest load, coherent skip). 
+ shared_tracker_files: list[Path] | None = None + if incremental: + all_trackers = discover_tracker_files(settings.data_root) + manifest = load_previous_manifest(settings.output_root) + shared_tracker_files, summary = filter_unchanged_trackers(all_trackers, manifest) console.print( - f" ✓ Processed {result.total_trackers} trackers " - f"({result.successful_trackers} ok, {result.failed_trackers} failed)\n" + f"[cyan]Incremental filter: queued {summary.queued}, " + f"skipped {summary.skipped} unchanged " + f"(new={summary.new}, changed={summary.changed}, " + f"incomplete={summary.previously_incomplete})[/cyan]\n" ) + if not shared_tracker_files: + console.print( + "[bold green]✓ No trackers need reprocessing — exiting cleanly[/bold green]\n" + ) + raise typer.Exit(0) + + # Step 2+3 – Extract, clean and build tables. + # clean_output wiring is `force` here, not `force or not incremental` like + # process-patient/process-product. Reason: run-pipeline's historical default + # (on `migration` and on this branch pre-change) was preserve-outputs — it + # never passed clean_output, inheriting the orchestrator's False default. + # --force on `migration` was a vestigial no-op (declared, plumbed, never + # read). With --force now actually wired through, opting in wipes both arms; + # without it, run-pipeline keeps its prior preserve-outputs contract. + if not skip_patient: + console.print("[bold]Steps 2–3/5:[/bold] Processing tracker files...\n") + try: + result = run_patient_pipeline( + tracker_files=shared_tracker_files, + max_workers=_workers, + clean_output=force, + show_progress=True, + console_log_level="WARNING", + ) - if result.failed_trackers > 0: - console.print("[bold yellow]Failed trackers:[/bold yellow]") - for tr in result.tracker_results: - if not tr.success: - console.print(f" • {tr.tracker_file.name}: {tr.error}") - console.print() + console.print( + f" ✓ Processed {result.total_trackers} trackers " + f"({result.successful_trackers} ok, {result.failed_trackers} failed)\n" + ) - if not result.success: - console.print("[bold red]✗ Pipeline failed – aborting upload steps[/bold red]\n") - raise typer.Exit(1) + _render_failed_trackers( + result, + mode="bullets", + title="Failed trackers", + leading_newline=False, + ) - except Exception as e: - console.print(f"\n[bold red]Error during processing: {e}[/bold red]\n") - raise typer.Exit(1) from e + if not result.success: + console.print("[bold red]✗ Pipeline failed – aborting upload steps[/bold red]\n") + raise typer.Exit(1) + + except Exception as e: + console.print(f"\n[bold red]Error during processing: {e}[/bold red]\n") + raise typer.Exit(1) from e + else: + console.print("[bold]Steps 2–3/5:[/bold] Skipping patient pipeline (--skip-patient)\n") tables_dir = settings.output_root / "tables" logs_dir = settings.output_root / "logs" @@ -697,6 +1146,77 @@ def run_pipeline_cmd( console.print(f" [bold red]Error creating clinic static table: {e}[/bold red]\n") raise typer.Exit(1) from e + # Product pipeline arm — soft failure posture: a crash here warns and + # continues so patient outputs (already on disk) still get uploaded. + if not skip_product: + console.print("[bold]Step 3c/5:[/bold] Running product pipeline...\n") + # Drop any stale product table from a prior run before re-running. + # Without --force, run-pipeline preserves outputs (clean_output=False), + # so a crash mid-product would otherwise leave the previous run's + # parquet for upload. 
With --force the orchestrator wipes anyway, so + # this unlink is redundant in that case but harmless. + (settings.output_root / "tables" / "product_data.parquet").unlink(missing_ok=True) + try: + product_result = run_product_pipeline( + tracker_files=shared_tracker_files, + max_workers=_workers, + clean_output=force, + show_progress=True, + console_log_level="WARNING", + ) + console.print( + f" ✓ Processed {product_result.total_trackers} product trackers " + f"({product_result.successful_trackers} ok, {product_result.failed_trackers} failed)\n" + ) + _render_failed_trackers( + product_result, + mode="bullets", + title="Failed product trackers", + leading_newline=False, + ) + except Exception as e: + console.print( + f"[bold yellow]Warning: product pipeline failed: {e}[/bold yellow]\n" + "[yellow]Continuing with patient outputs only.[/yellow]\n" + ) + else: + console.print("[bold]Step 3c/5:[/bold] Skipping product pipeline (--skip-product)\n") + + # Tracker metadata table — MD5 + per-tracker output presence. + # Not a skip-gated step; it's cheap and summarises the run's final state. + if settings.data_root.exists(): + console.print("[bold]Step 3d/5:[/bold] Creating tracker metadata table...\n") + try: + from a4d.tables.metadata import create_table_tracker_metadata + + create_table_tracker_metadata(settings.data_root, settings.output_root) + console.print(" ✓ Tracker metadata table created\n") + except Exception as e: + console.print( + f" [bold yellow]Warning: tracker metadata failed: {e}[/bold yellow]\n" + ) + + # Step 3e – Product-patient link validation (logging-only, post-tables). + # Skips silently if either arm's table is missing (e.g. --skip-product), or + # when --skip-patient leaves a stale patient_data_static.parquet on disk + # whose contents don't match this run's product output. + product_table = tables_dir / "product_data.parquet" + patient_static = tables_dir / "patient_data_static.parquet" + if not skip_patient and product_table.exists() and patient_static.exists(): + console.print("[bold]Step 3e/5:[/bold] Validating product-patient links...") + try: + from a4d.tables.product import link_product_patient + + product_df = pl.read_parquet(product_table) + mismatched = link_product_patient(product_df, patient_static) + console.print( + f" ✓ Link validation complete ({mismatched} unmatched product rows)\n" + ) + except Exception as e: + console.print( + f" [bold yellow]Warning: link validation failed: {e}[/bold yellow]\n" + ) + # Step 4 – Upload tables/ and logs/ to GCS under a timestamped prefix # Each run gets an isolated path: YYYY/MM/DD/HHMMSS/tables/ and .../logs/ # This avoids overwriting previous runs and keeps objectCreator permission sufficient. 
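The `--incremental` call sites above (in `_resolve_tracker_files` and in `run-pipeline`) lean on the `a4d.state` helpers imported at the top of `cli.py` (`load_previous_manifest`, `filter_unchanged_trackers`) without showing their internals. Below is a minimal sketch of the filtering contract those call sites rely on, assuming a manifest keyed by tracker file name with `md5` and `complete` fields and mirroring the `--incremental` help text (skip only when MD5 and completion state both match); the real `a4d.state` module may store and load the manifest differently.

```python
import hashlib
from dataclasses import dataclass
from pathlib import Path


@dataclass
class FilterSummary:
    queued: int
    skipped: int
    new: int
    changed: int
    previously_incomplete: int


def _md5(path: Path) -> str:
    # Hash the raw tracker bytes; the manifest holds the hash recorded by the last run.
    return hashlib.md5(path.read_bytes()).hexdigest()


def filter_unchanged_trackers(
    trackers: list[Path],
    manifest: dict[str, dict],  # assumed shape: {file name: {"md5": str, "complete": bool}}
) -> tuple[list[Path], FilterSummary]:
    """Queue a tracker unless its MD5 and completion state match the manifest."""
    queued: list[Path] = []
    new = changed = previously_incomplete = 0
    for tracker in trackers:
        entry = manifest.get(tracker.name)
        if entry is None:
            new += 1  # never seen before
        elif entry.get("md5") != _md5(tracker):
            changed += 1  # file content differs from the last run
        elif not entry.get("complete", False):
            previously_incomplete += 1  # last run did not finish this tracker
        else:
            continue  # unchanged and previously complete -> skip
        queued.append(tracker)
    summary = FilterSummary(
        queued=len(queued),
        skipped=len(trackers) - len(queued),
        new=new,
        changed=changed,
        previously_incomplete=previously_incomplete,
    )
    return queued, summary
```

The summary fields match the ones the CLI prints (`queued`, `skipped`, `new`, `changed`, `previously_incomplete`). Filtering once in `run-pipeline` and handing the same list to both arms is what keeps the patient and product queues consistent with a single manifest read, as the comment in the run-pipeline hunk notes.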
diff --git a/src/a4d/config.py b/src/a4d/config.py index c550c8b..1624487 100644 --- a/src/a4d/config.py +++ b/src/a4d/config.py @@ -21,7 +21,6 @@ class Settings(BaseSettings): model_config = SettingsConfigDict( env_file=".env", - env_file_encoding="utf-8", env_prefix="A4D_", case_sensitive=False, extra="ignore", @@ -48,6 +47,10 @@ class Settings(BaseSettings): error_val_character: str = "Undefined" error_val_date: str = "9999-09-09" + # Accepted tracker year range (raise on out-of-range sheet-name or filename) + min_tracker_year: int = 2017 + max_tracker_year: int = 2030 + @property def output_root(self) -> Path: """Computed output root path.""" diff --git a/src/a4d/errors.py b/src/a4d/errors.py index 11dc45b..719aab9 100644 --- a/src/a4d/errors.py +++ b/src/a4d/errors.py @@ -23,6 +23,7 @@ "invalid_tracker", # Tracker-level issues (missing columns, etc.) "function_call", # Generic function execution error "critical_abort", # Fatal error, tracker cannot be processed + "typo_rescued", # Known source-data typo substituted before parsing (informational) ] diff --git a/src/a4d/extract/common.py b/src/a4d/extract/common.py new file mode 100644 index 0000000..c2e9cf9 --- /dev/null +++ b/src/a4d/extract/common.py @@ -0,0 +1,201 @@ +"""Shared tracker-level extraction plumbing used by patient and product arms. + +Holds helpers that are not patient-specific or product-specific: detecting +month sheets in a workbook, parsing the tracker year, normalising Excel +error strings, and parsing the sheet-name month suffix. +""" + +import calendar +import re +from pathlib import Path + +import polars as pl +from loguru import logger + +from a4d.config import settings + + +def get_tracker_year(tracker_file: Path, month_sheets: list[str]) -> int: + """Extract tracker year from month sheet names or filename. + + Tries to parse year from month sheet names (e.g., "Jan24" -> 2024). + Falls back to extracting from filename if parsing fails. + Validates year is in reasonable range (2017-2030). + + Args: + tracker_file: Path to the tracker Excel file + month_sheets: List of month sheet names + + Returns: + Year of the tracker (e.g., 2024) + + Raises: + ValueError: If year cannot be determined or is out of valid range + + Example: + >>> get_tracker_year(Path("2024_Clinic.xlsx"), ["Jan24", "Feb24"]) + 2024 + """ + for sheet in month_sheets: + match = re.search(r"(\d{2})$", sheet) + if match: + year_suffix = int(match.group(1)) + year = 2000 + year_suffix # Assume 20xx until 2100 + logger.debug(f"Parsed year {year} from sheet name '{sheet}'") + + if not (settings.min_tracker_year <= year <= settings.max_tracker_year): + raise ValueError( + f"Year {year} is out of valid range " + f"({settings.min_tracker_year}-{settings.max_tracker_year}). " + f"Parsed from sheet name '{sheet}'" + ) + + return year + + match = re.search(r"(\d{4})", tracker_file.name) + if match: + year = int(match.group(1)) + logger.debug(f"Parsed year {year} from filename '{tracker_file.name}'") + + if not (settings.min_tracker_year <= year <= settings.max_tracker_year): + raise ValueError( + f"Year {year} is out of valid range " + f"({settings.min_tracker_year}-{settings.max_tracker_year}). " + f"Parsed from filename '{tracker_file.name}'" + ) + + return year + + raise ValueError( + f"Could not determine year from month sheets {month_sheets} or filename {tracker_file.name}" + ) + + +def find_month_sheets(workbook) -> list[str]: + """Find all month sheets in the tracker workbook. 
+ + Month sheets are identified by matching against month abbreviations + (Jan, Feb, Mar, etc.) and sorted by month number for consistent processing. + + Args: + workbook: openpyxl Workbook object + + Returns: + List of month sheet names found in the workbook, sorted by month number + (Jan=1, Feb=2, ..., Dec=12) + + Example: + >>> wb = load_workbook("tracker.xlsx") + >>> find_month_sheets(wb) + ['Jan24', 'Feb24', 'Mar24', ...] + """ + month_abbrs = list(calendar.month_abbr)[1:] # ['Jan', 'Feb', ...] + month_sheets = [] + + for sheet_name in workbook.sheetnames: + if any(sheet_name.startswith(abbr) for abbr in month_abbrs): + month_sheets.append(sheet_name) + + def get_month_number(sheet_name: str) -> int: + """Extract month number from sheet name (Jan=1, ..., Dec=12).""" + month_prefix = sheet_name[:3] + try: + return month_abbrs.index(month_prefix) + 1 + except ValueError: + return 999 # Push unrecognized sheets to end + + month_sheets.sort(key=get_month_number) + + logger.info(f"Found {len(month_sheets)} month sheets (sorted by month): {month_sheets}") + return month_sheets + + +def clean_excel_errors(df: pl.DataFrame) -> pl.DataFrame: + """Convert Excel error strings to NULL values. + + Excel error codes like #DIV/0!, #VALUE!, etc. are not usable values + and should be treated as missing data. + + Args: + df: DataFrame with potential Excel error strings + + Returns: + DataFrame with Excel errors converted to NULL + + Example: + >>> df = pl.DataFrame({"bmi": ["17.5", "#DIV/0!", "18.2"]}) + >>> clean_df = clean_excel_errors(df) + >>> clean_df["bmi"].to_list() + ['17.5', None, '18.2'] + """ + excel_errors = [ + "#DIV/0!", + "#VALUE!", + "#REF!", + "#NAME?", + "#NUM!", + "#N/A", + "#NULL!", + ] + + # Excel error strings only appear in String columns; filtering by dtype + # lets both patient and product callers share this helper regardless of + # which metadata columns (e.g. product_table_year:Float64) are present. + data_cols = [col for col in df.columns if df.schema[col] == pl.String] + + if not data_cols: + return df + + for error in excel_errors: + for col in data_cols: + count = (df[col] == error).sum() + if count > 0: + logger.debug(f"Converted {count} '{error}' values to NULL in column '{col}'") + + df = df.with_columns( + [ + pl.when(pl.col(col).is_in(excel_errors)).then(None).otherwise(pl.col(col)).alias(col) + for col in data_cols + ] + ) + + return df + + +def extract_tracker_month(sheet_name: str) -> int: + """Extract month number (1-12) from sheet name. + + Args: + sheet_name: Sheet name like "Jan24", "Feb24", etc. + + Returns: + Month number (1 for January, 2 for February, etc.) + + Raises: + ValueError: If month cannot be extracted or is out of valid range + + Example: + >>> extract_tracker_month("Jan24") + 1 + >>> extract_tracker_month("Dec23") + 12 + """ + month_abbrs = list(calendar.month_abbr)[1:] # ['Jan', 'Feb', ...] + + # Check first 3 characters + month_prefix = sheet_name[:3] + + if month_prefix in month_abbrs: + month_num = month_abbrs.index(month_prefix) + 1 # +1 because index is 0-based + + # Validate month is in valid range (1-12) + # This should always be true given the logic above, but check anyway for safety + if not (1 <= month_num <= 12): + raise ValueError( + f"Month number {month_num} is out of valid range (1-12). 
" + f"Parsed from sheet name '{sheet_name}'" + ) + + return month_num + + raise ValueError(f"Could not extract month from sheet name '{sheet_name}'") diff --git a/src/a4d/extract/patient.py b/src/a4d/extract/patient.py index 7c91a6d..dd2fb9d 100644 --- a/src/a4d/extract/patient.py +++ b/src/a4d/extract/patient.py @@ -4,7 +4,6 @@ evolved over the years with different formats and structures. """ -import calendar import re import warnings from pathlib import Path @@ -14,106 +13,26 @@ from openpyxl import load_workbook from a4d.errors import ErrorCollector +from a4d.extract.common import ( + clean_excel_errors, + extract_tracker_month, + find_month_sheets, + get_tracker_year, +) from a4d.reference.synonyms import ColumnMapper, load_patient_mapper +__all__ = [ + "clean_excel_errors", + "extract_tracker_month", + "find_month_sheets", + "get_tracker_year", +] + # Suppress openpyxl warnings about unsupported Excel features # We only read data, so these warnings are not actionable warnings.filterwarnings("ignore", category=UserWarning, module=r"openpyxl\..*") -def get_tracker_year(tracker_file: Path, month_sheets: list[str]) -> int: - """Extract tracker year from month sheet names or filename. - - Tries to parse year from month sheet names (e.g., "Jan24" -> 2024). - Falls back to extracting from filename if parsing fails. - Validates year is in reasonable range (2017-2030). - - Args: - tracker_file: Path to the tracker Excel file - month_sheets: List of month sheet names - - Returns: - Year of the tracker (e.g., 2024) - - Raises: - ValueError: If year cannot be determined or is out of valid range - - Example: - >>> get_tracker_year(Path("2024_Clinic.xlsx"), ["Jan24", "Feb24"]) - 2024 - """ - for sheet in month_sheets: - match = re.search(r"(\d{2})$", sheet) - if match: - year_suffix = int(match.group(1)) - year = 2000 + year_suffix # Assume 20xx until 2100 - logger.debug(f"Parsed year {year} from sheet name '{sheet}'") - - if not (2017 <= year <= 2030): # Match R pipeline validation - raise ValueError( - f"Year {year} is out of valid range (2017-2030). " - f"Parsed from sheet name '{sheet}'" - ) - - return year - - match = re.search(r"(\d{4})", tracker_file.name) - if match: - year = int(match.group(1)) - logger.debug(f"Parsed year {year} from filename '{tracker_file.name}'") - - if not (2017 <= year <= 2030): # Match R pipeline validation - raise ValueError( - f"Year {year} is out of valid range (2017-2030). " - f"Parsed from filename '{tracker_file.name}'" - ) - - return year - - raise ValueError( - f"Could not determine year from month sheets {month_sheets} or filename {tracker_file.name}" - ) - - -def find_month_sheets(workbook) -> list[str]: - """Find all month sheets in the tracker workbook. - - Month sheets are identified by matching against month abbreviations - (Jan, Feb, Mar, etc.) and sorted by month number for consistent processing. - - Args: - workbook: openpyxl Workbook object - - Returns: - List of month sheet names found in the workbook, sorted by month number - (Jan=1, Feb=2, ..., Dec=12) - - Example: - >>> wb = load_workbook("tracker.xlsx") - >>> find_month_sheets(wb) - ['Jan24', 'Feb24', 'Mar24', ...] - """ - month_abbrs = list(calendar.month_abbr)[1:] # ['Jan', 'Feb', ...] 
- month_sheets = [] - - for sheet_name in workbook.sheetnames: - if any(sheet_name.startswith(abbr) for abbr in month_abbrs): - month_sheets.append(sheet_name) - - def get_month_number(sheet_name: str) -> int: - """Extract month number from sheet name (Jan=1, ..., Dec=12).""" - month_prefix = sheet_name[:3] - try: - return month_abbrs.index(month_prefix) + 1 - except ValueError: - return 999 # Push unrecognized sheets to end - - month_sheets.sort(key=get_month_number) - - logger.info(f"Found {len(month_sheets)} month sheets (sorted by month): {month_sheets}") - return month_sheets - - def find_data_start_row(ws) -> int: """Find the first row containing patient data. @@ -389,63 +308,6 @@ def filter_valid_columns( return valid_headers, filtered_data -def clean_excel_errors(df: pl.DataFrame) -> pl.DataFrame: - """Convert Excel error strings to NULL values. - - Excel error codes like #DIV/0!, #VALUE!, etc. are not usable values - and should be treated as missing data. - - Args: - df: DataFrame with potential Excel error strings - - Returns: - DataFrame with Excel errors converted to NULL - - Example: - >>> df = pl.DataFrame({"bmi": ["17.5", "#DIV/0!", "18.2"]}) - >>> clean_df = clean_excel_errors(df) - >>> clean_df["bmi"].to_list() - ['17.5', None, '18.2'] - """ - excel_errors = [ - "#DIV/0!", - "#VALUE!", - "#REF!", - "#NAME?", - "#NUM!", - "#N/A", - "#NULL!", - ] - - metadata_cols = { - "tracker_year", - "tracker_month", - "clinic_id", - "patient_id", - "sheet_name", - "file_name", - } - data_cols = [col for col in df.columns if col not in metadata_cols] - - if not data_cols: - return df - - df = df.with_columns( - [ - pl.when(pl.col(col).is_in(excel_errors)).then(None).otherwise(pl.col(col)).alias(col) - for col in data_cols - ] - ) - - for error in excel_errors: - for col in data_cols: - count = (df[col] == error).sum() - if count > 0: - logger.debug(f"Converted {count} '{error}' values to NULL in column '{col}'") - - return df - - def extract_patient_data( tracker_file: Path, sheet_name: str, @@ -585,45 +447,6 @@ def harmonize_patient_data_columns( return renamed_df -def extract_tracker_month(sheet_name: str) -> int: - """Extract month number (1-12) from sheet name. - - Args: - sheet_name: Sheet name like "Jan24", "Feb24", etc. - - Returns: - Month number (1 for January, 2 for February, etc.) - - Raises: - ValueError: If month cannot be extracted or is out of valid range - - Example: - >>> extract_tracker_month("Jan24") - 1 - >>> extract_tracker_month("Dec23") - 12 - """ - month_abbrs = list(calendar.month_abbr)[1:] # ['Jan', 'Feb', ...] - - # Check first 3 characters - month_prefix = sheet_name[:3] - - if month_prefix in month_abbrs: - month_num = month_abbrs.index(month_prefix) + 1 # +1 because index is 0-based - - # Validate month is in valid range (1-12) - # This should always be true given the logic above, but check anyway for safety - if not (1 <= month_num <= 12): - raise ValueError( - f"Month number {month_num} is out of valid range (1-12). " - f"Parsed from sheet name '{sheet_name}'" - ) - - return month_num - - raise ValueError(f"Could not extract month from sheet name '{sheet_name}'") - - def read_all_patient_sheets( tracker_file: Path, mapper: ColumnMapper | None = None, diff --git a/src/a4d/extract/product.py b/src/a4d/extract/product.py new file mode 100644 index 0000000..30edc1c --- /dev/null +++ b/src/a4d/extract/product.py @@ -0,0 +1,431 @@ +"""Product data extraction from Excel tracker files. 
+ +Mirrors `src/a4d/extract/patient.py` structure but targets the product +section of each month sheet. Covers R Script 1 steps 1.1-1.10. +""" + +import warnings +from pathlib import Path + +import openpyxl +import polars as pl +from loguru import logger + +from a4d.errors import ErrorCollector +from a4d.extract.common import ( + clean_excel_errors, + extract_tracker_month, + find_month_sheets, + get_tracker_year, +) +from a4d.extract.wide_format import handle_wide_format_cells, handle_wide_format_columns +from a4d.reference.synonyms import ColumnMapper, load_product_mapper + +warnings.filterwarnings("ignore", category=UserWarning, module=r"openpyxl\..*") + + +class ProductSectionNotFoundError(ValueError): + """Raised when a month sheet has no identifiable product section.""" + + +def find_product_section(ws) -> tuple[int, int]: + """Locate the start and end rows of the product data region (R step 1.1). + + The start row is the header row containing the product/date/received + keywords. The end row is the row immediately before the patient + recruitment / patient data summary section. + + Returns (start_row, end_row) as 1-indexed Excel rows, both inclusive. + + Raises: + ProductSectionNotFoundError: if either bound cannot be located. + """ + max_row = ws.max_row or 0 + max_col = ws.max_column or 50 + + def _norm(raw_row) -> list[str]: + return [" ".join(str(cell or "").lower().split()) for cell in raw_row] + + start_row: int | None = None + end_row: int | None = None + + # 3-row sliding window over the streaming iterator: prev, curr, next. + # We evaluate curr (at curr_excel_row) against prev and next contexts. + prev_norm: list[str] | None = None + curr_norm: list[str] | None = None + curr_excel_row: int = 0 + + for excel_row, raw in enumerate( + ws.iter_rows(min_row=1, max_row=max_row, max_col=max_col, values_only=True), + start=1, + ): + next_norm = _norm(raw) + + if curr_norm is not None: + if start_row is None: + has_product = any("product" in cell for cell in curr_norm) + has_desc_support = any("description of support" in cell for cell in curr_norm) + has_date = any("date" in cell for cell in curr_norm) + has_received = any("received" in cell for cell in curr_norm) + has_units_received = any("units received" in cell for cell in curr_norm) + + matches_2024 = has_product and has_date and has_received + matches_2019_2021 = has_product and has_date and has_units_received + matches_2017_2018 = has_desc_support and has_date and has_units_received + + if matches_2024 or matches_2019_2021 or matches_2017_2018: + start_row = curr_excel_row + elif curr_excel_row > start_row: + has_patient_name = any("patient name" in cell for cell in next_norm) + has_patient_id = any("patient id" in cell for cell in next_norm) or any( + cell.strip() == "id" for cell in next_norm + ) + has_recruitment = any("patient recruitment" in cell for cell in curr_norm) + has_data_summary_above = prev_norm is not None and any( + "patient data summary" in cell for cell in prev_norm + ) + + if (has_recruitment or has_data_summary_above) and has_patient_name and has_patient_id: + end_row = curr_excel_row - 1 + break + + prev_norm = curr_norm + curr_norm = next_norm + curr_excel_row = excel_row + + if start_row is None: + raise ProductSectionNotFoundError("Could not find start of product section") + if end_row is None: + raise ProductSectionNotFoundError("Could not find end of product section") + + return start_row, end_row + + +def extract_product_data(ws, start_row: int, end_row: int) -> pl.DataFrame: + """Read the product 
region and promote its first row to headers (R step 1.2). + + Returns a DataFrame with all columns typed as ``pl.String``. Type + coercion is deferred to the cleaning phase (Sprint 3). + """ + max_col = ws.max_column or 50 + + all_rows = list( + ws.iter_rows(min_row=start_row, max_row=end_row, max_col=max_col, values_only=True) + ) + + if len(all_rows) < 2: + return pl.DataFrame() + + header_raw = all_rows[0] + data_rows = all_rows[1:] + + last_col = 0 + for i in range(len(header_raw) - 1, -1, -1): + if header_raw[i] is not None or any( + row[i] is not None for row in data_rows if i < len(row) + ): + last_col = i + 1 + break + + if last_col == 0: + return pl.DataFrame() + + from collections import defaultdict + + col_names: list[str] = [] + for col_idx in range(last_col): + raw_name = header_raw[col_idx] + col_names.append(str(raw_name) if raw_name is not None else f"_unnamed_{col_idx}") + + header_positions: dict[str, list[int]] = defaultdict(list) + for idx, name in enumerate(col_names): + header_positions[name].append(idx) + + duplicated = [h for h, positions in header_positions.items() if len(positions) > 1] + if duplicated: + logger.debug(f"Merging {len(duplicated)} duplicate column groups: {duplicated}") + + data_dict: dict[str, list[str | None]] = {} + for header, positions in header_positions.items(): + values: list[str | None] = [] + for row in data_rows: + if len(positions) == 1: + pos = positions[0] + raw = row[pos] if pos < len(row) else None + values.append(str(raw) if raw is not None else None) + else: + parts: list[str] = [] + for pos in positions: + raw = row[pos] if pos < len(row) else None + text = str(raw) if raw is not None else None + if text: + parts.append(text) + values.append(",".join(parts) if parts else None) + data_dict[header] = values + + return pl.DataFrame(data_dict, schema=dict.fromkeys(data_dict, pl.String)) + + +def add_product_metadata( + df: pl.DataFrame, + sheet_name: str, + tracker_month: int, + tracker_year: int, + file_name: str, + clinic_id: str, +) -> pl.DataFrame: + """Append sheet/tracker metadata columns (R step 1.8).""" + return df.with_columns( + [ + pl.lit(f"{tracker_month:02d}", dtype=pl.String).alias("product_table_month"), + pl.lit(float(tracker_year), dtype=pl.Float64).alias("product_table_year"), + pl.lit(sheet_name, dtype=pl.String).alias("product_sheet_name"), + pl.lit(file_name, dtype=pl.String).alias("file_name"), + pl.lit(clinic_id, dtype=pl.String).alias("clinic_id"), + ] + ) + + +def remove_header_rows(df: pl.DataFrame) -> pl.DataFrame: + """Drop residual header rows and fully empty rows (R step 1.6).""" + if df.height == 0: + return df + + df = df.filter(~pl.all_horizontal(pl.all().is_null())) + + if "product" in df.columns: + df = df.filter( + pl.col("product").is_null() + | ~pl.col("product") + .str.strip_chars() + .str.to_lowercase() + .is_in(["product", "patient data summary"]) + ) + + return df + + +def replace_extra_totals(df: pl.DataFrame) -> pl.DataFrame: + """Null ``product_units_released`` after a ``Total`` column (R step 1.9).""" + if df.height == 0: + return df + + if "product_released_to" in df.columns: + # R uses trimws(which="left") here, but readxl already strips trailing + # whitespace on read; openpyxl preserves it. Strip both to match R output. 
+ df = df.with_columns( + pl.col("product_released_to").str.strip_chars().alias("product_released_to") + ) + + if "product_units_released" not in df.columns: + return df + + col_idx = df.columns.index("product_units_released") + if col_idx < 2: + return df + + prev1 = df.columns[col_idx - 1] + prev2 = df.columns[col_idx - 2] + + def _total_mask(col_name: str) -> pl.Expr: + if df.schema[col_name] == pl.String: + return ( + pl.col(col_name) + .str.to_lowercase() + .str.contains("total", literal=True) + .fill_null(False) + ) + return pl.lit(False) + + mask = _total_mask(prev1) | _total_mask(prev2) + return df.with_columns( + pl.when(mask) + .then(None) + .otherwise(pl.col("product_units_released")) + .alias("product_units_released") + ) + + +def _count_orphan_released_units( + df: pl.DataFrame, + sheet_name: str, + file_name: str, + error_collector: ErrorCollector | None, +) -> None: + """R-parity warning for orphan ``product_units_released`` values. + + Counts harmonized rows where ``product_released_to`` is null/whitespace + yet ``product_units_released`` carries a value, and emits one + ErrorCollector entry per sheet. Mirrors R ``count_na_rows`` in + ``read_product_data.R`` (called before ``replace_extra_total_values_with_NA``). + + Whitespace-only ``product_released_to`` cells count as orphan: openpyxl + returns ``""`` for blank cells while ``_normalize_empty_strings_to_null`` + runs in clean rather than extract, so the check folds them into the null + branch here. + """ + if error_collector is None: + return + if "product_released_to" not in df.columns or "product_units_released" not in df.columns: + return + + released_to_blank = ( + pl.col("product_released_to").is_null() + | (pl.col("product_released_to").cast(pl.String).str.strip_chars() == "") + ) + has_released_units = pl.col("product_units_released").is_not_null() & ( + pl.col("product_units_released").cast(pl.String).str.strip_chars() != "" + ) + + count = df.filter(released_to_blank & has_released_units).height + if count == 0: + return + + logger.bind(error_code="invalid_tracker").warning( + f"Sheet '{sheet_name}' has {count} rows where product_released_to " + f"is missing next to product_units_released." + ) + error_collector.add_error( + file_name=file_name, + patient_id="unknown", + column="product_released_to", + original_value=str(count), + error_message=( + f"Sheet '{sheet_name}' has {count} rows where product_released_to " + f"is missing next to product_units_released." + ), + error_code="invalid_tracker", + script="script1", + function_name="read_product_data_step1", + ) + + +def _harmonize( + df: pl.DataFrame, + mapper: ColumnMapper, + sheet_name: str, + error_collector: ErrorCollector | None, + file_name: str, +) -> pl.DataFrame: + """Rename columns via the mapper then drop any column not in the synonym schema (R step 1.5).""" + unknown = [ + col + for col in df.columns + if not mapper.is_known_column(col) and col not in mapper.synonyms + ] + if unknown: + logger.bind(error_code="invalid_tracker").warning( + f"Sheet {sheet_name}: unknown column names: {unknown}." 
+ ) + if error_collector is not None: + for col in unknown: + error_collector.add_error( + file_name=file_name, + patient_id="unknown", + column=col, + original_value=col, + error_message=f"Sheet {sheet_name}: unknown column '{col}'", + error_code="invalid_tracker", + script="script1", + function_name="harmonize_input_data_columns", + ) + + df = mapper.rename_columns(df) + known = set(mapper.synonyms.keys()) + keep = [c for c in df.columns if c in known] + return df.select(keep) if keep else df.clear() + + +def read_all_product_sheets( + tracker_file: Path, + mapper: ColumnMapper | None = None, + error_collector: ErrorCollector | None = None, +) -> pl.DataFrame: + """Run steps 1.1-1.10 across every month sheet, returning one combined DataFrame.""" + tracker_file = Path(tracker_file) + mapper = mapper or load_product_mapper() + + wb = openpyxl.load_workbook( + tracker_file, read_only=True, data_only=True, keep_vba=False, keep_links=False + ) + + month_sheets = find_month_sheets(wb) + if not month_sheets: + raise ValueError(f"No month sheets found in {tracker_file.name}") + + year = get_tracker_year(tracker_file, month_sheets) + filename = tracker_file.stem + clinic_id = tracker_file.parent.name + + per_sheet: list[pl.DataFrame] = [] + for sheet_name in month_sheets: + ws = wb[sheet_name] + try: + start, end = find_product_section(ws) + except ProductSectionNotFoundError as exc: + logger.bind(error_code="invalid_tracker").warning( + f"Sheet {sheet_name}: {exc}. Skipping." + ) + if error_collector is not None: + error_collector.add_error( + file_name=filename, + patient_id="unknown", + column="", + original_value="", + error_message=f"Sheet {sheet_name}: {exc}", + error_code="invalid_tracker", + script="script1", + function_name="find_product_section", + ) + continue + + df = extract_product_data(ws, start, end) + if df.height == 0: + continue + + df = handle_wide_format_columns(df, filename) + df = handle_wide_format_cells(df, filename) + + df = _harmonize(df, mapper, sheet_name, error_collector, filename) + if df.height == 0 or df.width == 0: + continue + + try: + month = extract_tracker_month(sheet_name) + except ValueError as exc: + logger.warning(f"Sheet {sheet_name}: month unparseable ({exc}). Skipping.") + continue + + df = remove_header_rows(df) + df = add_product_metadata(df, sheet_name, month, year, filename, clinic_id) + _count_orphan_released_units(df, sheet_name, filename, error_collector) + df = replace_extra_totals(df) + + if df.height > 0: + per_sheet.append(df) + + wb.close() + + if not per_sheet: + logger.bind(error_code="empty_product_data").warning( + f"Empty product data: no product section found in any sheet of {filename}." + ) + return pl.DataFrame() + + return clean_excel_errors(pl.concat(per_sheet, how="diagonal_relaxed")) + + +def export_product_raw( + df: pl.DataFrame, + tracker_file: Path, + output_dir: Path, +) -> Path: + """Write raw product DataFrame to ``{output_dir}/{tracker_name}_product_raw.parquet``.""" + tracker_file = Path(tracker_file) + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + output_path = output_dir / f"{tracker_file.stem}_product_raw.parquet" + logger.info(f"Writing {df.height} rows to {output_path}") + df.write_parquet(output_path) + return output_path diff --git a/src/a4d/extract/wide_format.py b/src/a4d/extract/wide_format.py new file mode 100644 index 0000000..e79f9e7 --- /dev/null +++ b/src/a4d/extract/wide_format.py @@ -0,0 +1,145 @@ +"""Wide-format Mandalay tracker reshaping (R steps 1.4a and 1.4b). 
+ +Two distinct formats exist in the Mandalay Children's Hospital trackers +that need long-format conversion before the main extraction flow +continues. +""" + +import re + +import polars as pl + + +def _rows_to_df(rows: list[dict], schema: dict) -> pl.DataFrame: + """Build a DataFrame from row-dicts using explicit schema (avoids inference pitfalls).""" + if not rows: + return pl.DataFrame(schema=schema) + columns = {name: [row.get(name) for row in rows] for name in schema} + return pl.DataFrame(columns, schema=schema) + +_RELEASED_TO_WIDE = "Released To (select from drop down list)" +_TOTAL = "Total Units Released" +_PER_PERSON = "Units Released per person" + + +def handle_wide_format_columns(df: pl.DataFrame, filename: str) -> pl.DataFrame: + """Expand 2020-2021 Mandalay wide-format columns into rows (R step 1.4a). + + Gate: both ``Total Units Released`` and ``Units Released per person`` + columns must be present. Otherwise return ``df`` unchanged. + """ + del filename # gate is column-based, not filename-based + + cols = df.columns + if _TOTAL not in cols or _PER_PERSON not in cols or _RELEASED_TO_WIDE not in cols: + return df + if df.height == 0: + return df + + start_col_idx = cols.index(_RELEASED_TO_WIDE) + 1 + end_col_idx = cols.index(_PER_PERSON) - 1 + if start_col_idx > end_col_idx: + return df + + intermediate_cols = cols[start_col_idx : end_col_idx + 1] + + date_col = "Date" if "Date" in cols else next((c for c in cols if "Date" in c), None) + + new_rows: list[dict] = [] + for row in df.iter_rows(named=True): + new_rows.append(dict(row)) + + total = row.get(_TOTAL) + per_person = row.get(_PER_PERSON) + + if total is None or per_person is None: + continue + if str(total).strip() == "" or str(per_person).strip() == "": + continue + if total == per_person: + continue + + for col in intermediate_cols: + cell = row.get(col) + if cell is None: + continue + if isinstance(cell, str) and cell.strip() == "": + continue + + new_row = dict.fromkeys(cols) + if date_col is not None: + new_row[date_col] = row.get(date_col) + new_row[_RELEASED_TO_WIDE] = cell + new_row[_PER_PERSON] = per_person + new_rows.append(new_row) + + return _rows_to_df(new_rows, dict(df.schema)) + + +def handle_wide_format_cells(df: pl.DataFrame, filename: str) -> pl.DataFrame: + """Split 2017-2019 Mandalay comma-separated ``Released To`` cells (R step 1.4b).""" + if not any(pat in filename for pat in ("2017_Mandalay", "2018_Mandalay", "2019_Mandalay")): + return df + if df.height == 0: + return df + + def _find_col(substr: str) -> str | None: + return next((c for c in df.columns if substr in c), None) + + released_to = _find_col("Released To") + units_released = _find_col("Units Released") + date_col = _find_col("Date") + received_from = _find_col("Received From") + + if not all([released_to, units_released, date_col, received_from]): + return df + + pattern = re.compile(r"(-\s*\d)|(,\s*)") + if not any( + v is not None and pattern.search(str(v)) for v in df[released_to].to_list() + ): + return df + + cols = df.columns + new_rows: list[dict] = [] + for row in df.iter_rows(named=True): + new_rows.append(dict(row)) + + cell = row.get(released_to) + if cell is None or "," not in str(cell): + continue + + for fragment in str(cell).split(","): + fragment = fragment.strip() + if not fragment: + continue + + if "-" in fragment: + name, qty = fragment.split("-", 1) + name = name.strip() + qty = qty.strip() + else: + name, qty = fragment, None + + new_row = dict.fromkeys(cols) + new_row[date_col] = row.get(date_col) + 
new_row[received_from] = row.get(received_from) + new_row[released_to] = name or None + new_row[units_released] = qty + new_rows.append(new_row) + + result = _rows_to_df(new_rows, dict(df.schema)) + + comma_mask = ( + pl.col(released_to) + .str.contains(",", literal=True) + .fill_null(False) + ) + result = result.with_columns( + [ + pl.when(comma_mask).then(None).otherwise(pl.col(units_released)).alias(units_released), + pl.when(comma_mask).then(None).otherwise(pl.col(released_to)).alias(released_to), + ] + ) + + return result diff --git a/src/a4d/gcp/bigquery.py b/src/a4d/gcp/bigquery.py index 0c1ea6e..b8efe9f 100644 --- a/src/a4d/gcp/bigquery.py +++ b/src/a4d/gcp/bigquery.py @@ -7,8 +7,9 @@ from pathlib import Path +import polars as pl from google.cloud import bigquery -from google.api_core.exceptions import NotFound +from google.api_core.exceptions import GoogleAPIError, NotFound from loguru import logger from a4d.config import settings @@ -36,8 +37,10 @@ "patient_data_static.parquet": "patient_data_static", "patient_data_monthly.parquet": "patient_data_monthly", "patient_data_annual.parquet": "patient_data_annual", + "product_data.parquet": "product_data", "clinic_data_static.parquet": "clinic_data_static", "table_logs.parquet": "logs", + "tracker_metadata.parquet": "tracker_metadata", } @@ -195,3 +198,83 @@ def load_pipeline_tables( logger.info(f"Successfully loaded {len(results)}/{len(PARQUET_TO_TABLE)} tables") return results + + +def select_tracker_metadata( + client: bigquery.Client | None = None, + dataset: str | None = None, + project_id: str | None = None, +) -> pl.DataFrame | None: + """Read ``file_name, clinic_code, md5, complete`` from BigQuery. + + Used by ``a4d.state.source.load_previous_manifest`` to query the previous + run's tracker manifest for incremental processing. + + Returns ``None`` (rather than raising) on: + + * authentication failure (``DefaultCredentialsError``) — caller falls back + to local parquet, + * missing table (``NotFound``) — first-ever run, no manifest yet, + * any other ``GoogleAPIError`` — network issues etc., fall back rather + than block the pipeline. + + Schema fallback: if the BQ table predates the ``complete`` column being + published, the query is retried without it and ``complete`` is synthesised + as ``False`` for every row — forces a full reprocess, which is the safe + default when manifest provenance is uncertain. + """ + project_id = project_id or settings.project_id + dataset = dataset or settings.dataset + + if client is None: + try: + client = get_bigquery_client(project_id) + except Exception as e: + logger.warning(f"BigQuery client unavailable, skipping metadata query: {e}") + return None + + table_ref = f"{project_id}.{dataset}.tracker_metadata" + full_query = f"SELECT file_name, clinic_code, md5, complete FROM `{table_ref}`" + + full_schema = { + "file_name": pl.Utf8, + "clinic_code": pl.Utf8, + "md5": pl.Utf8, + "complete": pl.Boolean, + } + + try: + rows = list(client.query(full_query).result()) + data = {col: [r[col] for r in rows] for col in full_schema} + return pl.DataFrame(data, schema=full_schema) + except NotFound: + logger.info(f"BigQuery table not found, no previous manifest: {table_ref}") + return None + except GoogleAPIError as e: + # Schema fallback: retry without `complete` if the column is missing + # in the deployed table. The synthesised complete=False forces a + # full reprocess. 
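+            # (Assumption, for orientation only) BigQuery usually reports a
+            # missing column as something like "Unrecognized name: complete",
+            # which is exactly the substring the check below sniffs for.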
+ message = str(e).lower() + if "complete" in message and ("unrecognized name" in message or "not found" in message): + logger.warning( + f"BigQuery {table_ref} missing 'complete' column; " + "retrying without it and forcing full reprocess" + ) + try: + rows = list(client.query( + f"SELECT file_name, clinic_code, md5 FROM `{table_ref}`" + ).result()) + fallback_schema = {k: v for k, v in full_schema.items() if k != "complete"} + data = {col: [r[col] for r in rows] for col in fallback_schema} + df = pl.DataFrame(data, schema=fallback_schema) + return df.with_columns(pl.lit(False).alias("complete")) + except GoogleAPIError as retry_err: + logger.warning( + f"BigQuery schema-fallback query failed: {retry_err}" + ) + return None + logger.warning(f"BigQuery query failed for {table_ref}: {e}") + return None + except Exception as e: + logger.warning(f"Unexpected error querying {table_ref}: {e}") + return None diff --git a/src/a4d/gcp/storage.py b/src/a4d/gcp/storage.py index 1dc1716..bf3aab8 100644 --- a/src/a4d/gcp/storage.py +++ b/src/a4d/gcp/storage.py @@ -145,8 +145,7 @@ def upload_output( def _blob_name(file_path: Path) -> str: relative = file_path.relative_to(source_dir) - name = f"{prefix}/{relative}" if prefix else str(relative) - return name.replace("\\", "/") + return f"{prefix}/{relative}" if prefix else str(relative) uploaded: list[str] = [] diff --git a/src/a4d/pipeline/patient.py b/src/a4d/pipeline/patient.py index d9192cc..4c96205 100644 --- a/src/a4d/pipeline/patient.py +++ b/src/a4d/pipeline/patient.py @@ -18,6 +18,7 @@ create_table_patient_data_annual, create_table_patient_data_monthly, create_table_patient_data_static, + read_cleaned_patient_data, ) @@ -88,18 +89,21 @@ def process_patient_tables(cleaned_dir: Path, output_dir: Path) -> dict[str, Pat logger.warning("No cleaned files found, skipping table creation") return {} + patient_data = read_cleaned_patient_data(cleaned_files) + logger.info(f"Loaded combined patient dataframe: {patient_data.shape}") + tables = {} logger.info("Creating static patient table") - static_path = create_table_patient_data_static(cleaned_files, output_dir) + static_path = create_table_patient_data_static(patient_data, output_dir) tables["static"] = static_path logger.info("Creating monthly patient table") - monthly_path = create_table_patient_data_monthly(cleaned_files, output_dir) + monthly_path = create_table_patient_data_monthly(patient_data, output_dir) tables["monthly"] = monthly_path logger.info("Creating annual patient table") - annual_path = create_table_patient_data_annual(cleaned_files, output_dir) + annual_path = create_table_patient_data_annual(patient_data, output_dir) tables["annual"] = annual_path logger.info(f"Created {len(tables)} patient tables") @@ -111,7 +115,6 @@ def run_patient_pipeline( max_workers: int = 1, output_root: Path | None = None, skip_tables: bool = False, - force: bool = False, clean_output: bool = False, progress_callback: Callable[[str, bool], None] | None = None, show_progress: bool = False, @@ -134,7 +137,6 @@ def run_patient_pipeline( max_workers: Number of parallel workers (1 = sequential) output_root: Output directory (None = use settings.output_root) skip_tables: If True, only extract + clean, skip table creation - force: If True, reprocess even if outputs exist clean_output: If True, wipe patient_data_raw/, patient_data_cleaned/, tables/ before run progress_callback: Optional callback(tracker_name, success) called after each tracker show_progress: If True, show tqdm progress bar diff --git 
a/src/a4d/pipeline/product.py b/src/a4d/pipeline/product.py new file mode 100644 index 0000000..f02ee7a --- /dev/null +++ b/src/a4d/pipeline/product.py @@ -0,0 +1,217 @@ +"""Product pipeline orchestration. + +Mirrors ``pipeline/patient.py``: per-tracker extract+clean (optionally +parallel), followed by final-table creation. +""" + +import os +import shutil +from collections.abc import Callable +from concurrent.futures import ProcessPoolExecutor, as_completed +from datetime import datetime +from pathlib import Path + +from loguru import logger +from tqdm import tqdm + +from a4d.config import settings +from a4d.logging import setup_logging +from a4d.pipeline.models import PipelineResult, TrackerResult +from a4d.pipeline.patient import discover_tracker_files +from a4d.pipeline.tracker import process_tracker_product +from a4d.tables.product import create_table_product_data + + +def _init_worker_logging(output_root: Path) -> None: + """Initialize logging for worker processes (called once per ProcessPoolExecutor worker).""" + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + pid = os.getpid() + setup_logging( + output_root=output_root, + log_name=f"worker_product_{timestamp}_pid{pid}", + console_level="ERROR", + ) + + +def process_product_tables(cleaned_dir: Path, output_dir: Path) -> dict[str, Path]: + """Create the final product table from cleaned parquets. + + Thin wrapper around ``tables.product.create_table_product_data``. + Unlike the patient pipeline (static/monthly/annual) the product pipeline + emits a single ``product_data`` table. + """ + logger.info("Creating final product table from cleaned data") + + cleaned_files = list(cleaned_dir.glob("*_product_cleaned.parquet")) + logger.info(f"Found {len(cleaned_files)} cleaned product parquet files") + + if not cleaned_files: + logger.warning("No cleaned product files found, skipping table creation") + return {} + + product_data_path = create_table_product_data(cleaned_files, output_dir) + return {"product_data": product_data_path} + + +def run_product_pipeline( + tracker_files: list[Path] | None = None, + max_workers: int = 1, + output_root: Path | None = None, + skip_tables: bool = False, + clean_output: bool = False, + progress_callback: Callable[[str, bool], None] | None = None, + show_progress: bool = False, + console_log_level: str | None = None, +) -> PipelineResult: + """Run the end-to-end product pipeline. + + Mirrors ``run_patient_pipeline`` argument-for-argument. + """ + if output_root is None: + output_root = settings.output_root + + if clean_output: + for subdir in ("product_data_raw", "product_data_cleaned"): + target = output_root / subdir + if target.exists(): + shutil.rmtree(target) + logger.info(f"Cleaned output directory: {target}") + # Patient tables share output_root/tables; remove only the product table. 
+ product_table = output_root / "tables" / "product_data.parquet" + if product_table.exists(): + product_table.unlink() + logger.info(f"Cleaned product table: {product_table}") + + setup_logging( + output_root, + "pipeline_product", + console_level=console_log_level if console_log_level else "INFO", + ) + logger.info("Starting product pipeline") + logger.info(f"Output directory: {output_root}") + logger.info(f"Max workers: {max_workers}") + + if tracker_files is None: + logger.info(f"Discovering tracker files in: {settings.data_root}") + tracker_files = discover_tracker_files(settings.data_root) + else: + tracker_files = [Path(f) for f in tracker_files] + + logger.info(f"Found {len(tracker_files)} tracker files to process") + + if not tracker_files: + logger.warning("No tracker files found") + return PipelineResult.from_tracker_results([], {}) + + tracker_results: list[TrackerResult] = [] + + if max_workers == 1: + logger.info("Processing trackers sequentially") + + iterator = ( + tqdm(tracker_files, desc="Processing product trackers", unit="file") + if show_progress + else tracker_files + ) + + for tracker_file in iterator: + if isinstance(iterator, tqdm): + iterator.set_description(f"Processing {tracker_file.name}") + + result = process_tracker_product( + tracker_file=tracker_file, + output_root=output_root, + mapper=None, + ) + tracker_results.append(result) + + if progress_callback: + progress_callback(tracker_file.name, result.success) + + if result.success: + logger.info(f"✓ Successfully processed: {tracker_file.name}") + if show_progress: + tqdm.write(f"✓ {tracker_file.name}") + else: + logger.error(f"✗ Failed to process: {tracker_file.name} - {result.error}") + if show_progress: + tqdm.write(f"✗ {tracker_file.name}: {result.error}") + + else: + logger.info(f"Processing trackers in parallel ({max_workers} workers)") + with ProcessPoolExecutor( + max_workers=max_workers, initializer=_init_worker_logging, initargs=(output_root,) + ) as executor: + futures = { + executor.submit( + process_tracker_product, + tracker_file, + output_root, + None, + ): tracker_file + for tracker_file in tracker_files + } + + futures_iterator = as_completed(futures) + if show_progress: + futures_iterator = tqdm( + futures_iterator, + total=len(futures), + desc="Processing product trackers", + unit="file", + ) + + for future in futures_iterator: + tracker_file = futures[future] + try: + result = future.result() + tracker_results.append(result) + + if progress_callback: + progress_callback(tracker_file.name, result.success) + + if result.success: + logger.info(f"✓ Completed: {tracker_file.name}") + if show_progress: + tqdm.write(f"✓ {tracker_file.name}") + else: + logger.error(f"✗ Failed: {tracker_file.name} - {result.error}") + if show_progress: + tqdm.write(f"✗ {tracker_file.name}: {result.error}") + except Exception as e: + logger.exception(f"Exception processing {tracker_file.name}") + if show_progress: + tqdm.write(f"✗ {tracker_file.name}: Exception - {str(e)}") + tracker_results.append( + TrackerResult( + tracker_file=tracker_file, + tracker_name=tracker_file.stem, + success=False, + error=str(e), + ) + ) + + successful = sum(1 for r in tracker_results if r.success) + failed = len(tracker_results) - successful + logger.info(f"Tracker processing complete: {successful} successful, {failed} failed") + + tables: dict[str, Path] = {} + if not skip_tables: + try: + cleaned_dir = output_root / "product_data_cleaned" + tables_dir = output_root / "tables" + tables = process_product_tables(cleaned_dir, 
tables_dir) + logger.info(f"Created {len(tables)} product tables total") + except Exception: + logger.exception("Failed to create product tables") + else: + logger.info("Skipping product table creation (skip_tables=True)") + + result = PipelineResult.from_tracker_results(tracker_results, tables) + + if result.success: + logger.info("✓ Product pipeline completed successfully") + else: + logger.warning(f"✗ Product pipeline completed with {failed} failures") + + return result diff --git a/src/a4d/pipeline/tracker.py b/src/a4d/pipeline/tracker.py index e377ab5..1fb16e6 100644 --- a/src/a4d/pipeline/tracker.py +++ b/src/a4d/pipeline/tracker.py @@ -5,11 +5,13 @@ from loguru import logger from a4d.clean.patient import clean_patient_file +from a4d.clean.product import clean_product_file from a4d.errors import ErrorCollector from a4d.extract.patient import export_patient_raw, read_all_patient_sheets +from a4d.extract.product import export_product_raw, read_all_product_sheets from a4d.logging import file_logger from a4d.pipeline.models import TrackerResult -from a4d.reference.synonyms import ColumnMapper +from a4d.reference.synonyms import ColumnMapper, load_product_mapper def process_tracker_patient( @@ -111,3 +113,79 @@ def process_tracker_patient( success=False, error=str(e), ) + + +def process_tracker_product( + tracker_file: Path, output_root: Path, mapper: ColumnMapper | None = None +) -> TrackerResult: + """Process single tracker file: extract + clean product data. + + Mirrors ``process_tracker_patient``. Non-fatal cleaning errors keep + ``success=True`` and surface via ``error_breakdown``; only unhandled + exceptions set ``success=False``. + """ + tracker_name = tracker_file.stem + + try: + raw_dir = output_root / "product_data_raw" + cleaned_dir = output_root / "product_data_cleaned" + raw_dir.mkdir(parents=True, exist_ok=True) + cleaned_dir.mkdir(parents=True, exist_ok=True) + + cleaned_output = cleaned_dir / f"{tracker_name}_product_cleaned.parquet" + + with file_logger(f"{tracker_name}_product", output_root): + logger.info(f"Processing tracker: {tracker_file.name}") + + logger.info("Step 1: Extracting product data from Excel") + error_collector = ErrorCollector() + + mapper = mapper or load_product_mapper() + + df_raw = read_all_product_sheets( + tracker_file=tracker_file, + mapper=mapper, + error_collector=error_collector, + ) + logger.info(f"Extracted {len(df_raw)} rows") + + raw_output = export_product_raw( + df=df_raw, tracker_file=tracker_file, output_dir=raw_dir + ) + logger.info(f"Raw parquet saved: {raw_output}") + + logger.info("Step 2: Cleaning product data") + clean_product_file( + raw_parquet_path=raw_output, + output_parquet_path=cleaned_output, + error_collector=error_collector, + ) + + error_count = len(error_collector) + error_breakdown = error_collector.get_error_summary() + logger.info(f"Cleaned parquet saved: {cleaned_output}") + logger.info(f"Total data quality errors: {error_count}") + if error_breakdown: + logger.info(f"Error breakdown: {error_breakdown}") + + return TrackerResult( + tracker_file=tracker_file, + tracker_name=tracker_name, + raw_output=raw_output, + cleaned_output=cleaned_output, + success=True, + error=None, + cleaning_errors=error_count, + error_breakdown=error_breakdown if error_breakdown else None, + ) + + except Exception as e: + logger.bind(error_code="critical_abort").exception(f"Failed to process tracker: {tracker_file.name}") + return TrackerResult( + tracker_file=tracker_file, + tracker_name=tracker_name, + raw_output=None, + 
cleaned_output=None, + success=False, + error=str(e), + ) diff --git a/src/a4d/reference/products.py b/src/a4d/reference/products.py new file mode 100644 index 0000000..ed30e05 --- /dev/null +++ b/src/a4d/reference/products.py @@ -0,0 +1,77 @@ +"""Product reference-data loaders (R cleaning steps 2.19 and 2.20). + +Loads the known-products list and product category mapping from the +``Stock_Summary`` sheet of ``master_tracker_variables.xlsx`` in the +shared ``reference_data/`` directory. +""" + +from functools import lru_cache + +import openpyxl +import polars as pl + +from a4d.reference.loaders import get_reference_data_path + + +@lru_cache(maxsize=1) +def _read_stock_summary() -> pl.DataFrame: + """Read the Stock_Summary sheet as a two-column DataFrame. + + Returns a frame with columns ``product`` (lowercased values) and + ``product_category`` (original casing). Null rows and duplicates + are dropped. + """ + xlsx_path = get_reference_data_path("master_tracker_variables.xlsx") + wb = openpyxl.load_workbook(xlsx_path, data_only=True, read_only=True) + try: + ws = wb["Stock_Summary"] + rows = list(ws.iter_rows(min_row=2, values_only=True)) + finally: + wb.close() + + products: list[str] = [] + categories: list[str | None] = [] + for row in rows: + if not row: + continue + name = row[0] + if name is None: + continue + name_str = str(name).strip() + if not name_str: + continue + products.append(name_str.lower()) + cat = row[1] if len(row) > 1 else None + categories.append(str(cat).strip() if cat is not None else None) + + return pl.DataFrame( + {"product": products, "product_category": categories} + ).unique(subset=["product"], keep="first", maintain_order=True) + + +def load_known_products() -> list[str]: + """Load the lowercased list of known product names. + + Covers the reference data read for R step 2.19 + (``report_unknown_products``). Reads the ``Stock_Summary`` sheet of + ``reference_data/master_tracker_variables.xlsx`` and returns each + product name lowercased. + + Returns: + List of lowercased product names. + """ + return _read_stock_summary()["product"].to_list() + + +def load_product_categories() -> pl.DataFrame: + """Load the product-to-category mapping. + + Covers the reference data read for R step 2.20 + (``add_product_categories``). Reads the ``Stock_Summary`` sheet of + ``reference_data/master_tracker_variables.xlsx`` and returns a + two-column DataFrame. + + Returns: + DataFrame with columns ``product`` and ``product_category``. + """ + return _read_stock_summary() diff --git a/src/a4d/reference/synonyms.py b/src/a4d/reference/synonyms.py index 5bf9883..411e9e8 100644 --- a/src/a4d/reference/synonyms.py +++ b/src/a4d/reference/synonyms.py @@ -5,6 +5,7 @@ """ import re +from functools import lru_cache from pathlib import Path import polars as pl @@ -112,6 +113,9 @@ def _build_lookup(self) -> dict[str, str]: if not synonym_list: continue + if isinstance(synonym_list, str): + synonym_list = [synonym_list] + for synonym in synonym_list: # Sanitize the synonym key before adding to lookup sanitized_key = sanitize_str(synonym) @@ -297,9 +301,13 @@ def validate_required_columns( raise ValueError(f"Required columns missing after renaming: {missing}") +@lru_cache(maxsize=1) def load_patient_mapper() -> ColumnMapper: """Load the patient data column mapper. + Cached so callers on the per-tracker hot path don't re-read the YAML. + Cache is per-process; ProcessPoolExecutor workers get their own. 
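+
+    A quick way to see the cache in action (illustrative doctest; assumes the
+    patient synonym YAML is available on disk):
+
+        >>> load_patient_mapper() is load_patient_mapper()
+        True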
+ Returns: ColumnMapper for patient data @@ -311,6 +319,7 @@ def load_patient_mapper() -> ColumnMapper: return ColumnMapper(path) +@lru_cache(maxsize=1) def load_product_mapper() -> ColumnMapper: """Load the product data column mapper. diff --git a/src/a4d/state/__init__.py b/src/a4d/state/__init__.py index e69de29..2a934fc 100644 --- a/src/a4d/state/__init__.py +++ b/src/a4d/state/__init__.py @@ -0,0 +1,15 @@ +"""Incremental processing: skip trackers whose MD5 + completion state match a +previous run's manifest. +""" + +from a4d.state.filter import FilterSummary, filter_unchanged_trackers +from a4d.state.manifest import Manifest, ManifestEntry +from a4d.state.source import load_previous_manifest + +__all__ = [ + "FilterSummary", + "Manifest", + "ManifestEntry", + "filter_unchanged_trackers", + "load_previous_manifest", +] diff --git a/src/a4d/state/filter.py b/src/a4d/state/filter.py new file mode 100644 index 0000000..ae0167e --- /dev/null +++ b/src/a4d/state/filter.py @@ -0,0 +1,83 @@ +"""Filter discovered tracker files against the previous run's manifest. + +The pipeline currently re-processes every ``.xlsx`` under ``data_root`` on every +run. With the ``--incremental`` flag, the CLI calls +:func:`filter_unchanged_trackers` between discovery and the worker loop to drop +trackers whose bytes haven't changed AND whose previous run completed all four +output stages. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path + +from loguru import logger + +from a4d.state.manifest import Manifest +from a4d.tables.metadata import md5_file + + +@dataclass(frozen=True) +class FilterSummary: + queued: int + skipped: int + new: int + changed: int + previously_incomplete: int + + +def filter_unchanged_trackers( + tracker_files: list[Path], + manifest: Manifest, +) -> tuple[list[Path], FilterSummary]: + """Return trackers needing reprocessing, plus a counts summary. + + A tracker is **queued** (not skipped) when any of: + + * ``(clinic_code, file_name)`` is absent from the manifest (new tracker), + * current MD5 differs from the manifest MD5 (changed tracker), or + * manifest entry has ``complete=False`` (previous run didn't finish all + four output stages — re-run to fill the gaps). + + ``clinic_code`` is the parent folder name and ``file_name`` is the path + stem, matching ``tables/metadata.py``'s schema. MD5 is computed via the + same chunked helper used by the metadata producer, so hashes are + bit-comparable. + """ + queued: list[Path] = [] + new = changed = incomplete = 0 + + for path in tracker_files: + clinic_code = path.parent.name + file_name = path.stem + entry = manifest.get(clinic_code, file_name) + + if entry is None: + queued.append(path) + new += 1 + continue + if not entry.complete: + queued.append(path) + incomplete += 1 + continue + if md5_file(path) != entry.md5: + queued.append(path) + changed += 1 + continue + # Unchanged + complete: skip. 
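+
+    # The three queue reasons above are mutually exclusive per tracker, so
+    # new + changed + previously_incomplete always equals len(queued).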
+ + summary = FilterSummary( + queued=len(queued), + skipped=len(tracker_files) - len(queued), + new=new, + changed=changed, + previously_incomplete=incomplete, + ) + logger.info( + f"Incremental filter: queued {summary.queued} " + f"(new={summary.new}, changed={summary.changed}, " + f"incomplete={summary.previously_incomplete}); " + f"skipped {summary.skipped} unchanged" + ) + return queued, summary diff --git a/src/a4d/state/manifest.py b/src/a4d/state/manifest.py new file mode 100644 index 0000000..c990476 --- /dev/null +++ b/src/a4d/state/manifest.py @@ -0,0 +1,38 @@ +"""Previous-run manifest used by incremental processing. + +A ``Manifest`` is an in-memory snapshot of ``tracker_metadata.parquet``: one +entry per (clinic_code, file_name) pair, capturing the previous run's MD5 and +whether the run completed (all four per-tracker output presence flags were +True). The filter step compares this snapshot against the current on-disk +trackers to decide which to re-process. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field + + +ManifestKey = tuple[str, str] +"""(clinic_code, file_name) — file_name is the path stem, matching +``tracker_metadata.parquet``'s ``file_name`` column.""" + + +@dataclass(frozen=True) +class ManifestEntry: + md5: str + complete: bool + + +@dataclass(frozen=True) +class Manifest: + entries: dict[ManifestKey, ManifestEntry] = field(default_factory=dict) + + @classmethod + def empty(cls) -> Manifest: + return cls(entries={}) + + def get(self, clinic_code: str, file_name: str) -> ManifestEntry | None: + return self.entries.get((clinic_code, file_name)) + + def __len__(self) -> int: + return len(self.entries) diff --git a/src/a4d/state/source.py b/src/a4d/state/source.py new file mode 100644 index 0000000..bf42d76 --- /dev/null +++ b/src/a4d/state/source.py @@ -0,0 +1,72 @@ +"""Load the previous run's manifest from BigQuery or a local parquet. + +Source precedence: + +1. BigQuery (``select_tracker_metadata``) — authoritative when running in + Cloud Run with credentials available. +2. Local ``output_root/tables/tracker_metadata.parquet`` — fallback for + developers without ``gcloud auth`` configured, or when BQ is unreachable. +3. Empty manifest — first-ever run, every tracker queues. +""" + +from __future__ import annotations + +from pathlib import Path + +import polars as pl +from loguru import logger + +from a4d.state.manifest import Manifest, ManifestEntry + + +def _build_manifest(df: pl.DataFrame) -> Manifest: + """Convert a (file_name, clinic_code, md5, complete) frame into a Manifest.""" + entries: dict[tuple[str, str], ManifestEntry] = {} + for row in df.iter_rows(named=True): + key = (row["clinic_code"], row["file_name"]) + entries[key] = ManifestEntry(md5=row["md5"], complete=bool(row["complete"])) + return Manifest(entries=entries) + + +def load_previous_manifest( + output_root: Path, + *, + prefer_bigquery: bool = True, +) -> Manifest: + """Load the previous run's tracker manifest. + + Args: + output_root: Pipeline output root; the local fallback reads + ``output_root/tables/tracker_metadata.parquet`` from here. + prefer_bigquery: If True, try BigQuery first. Set to False to skip the + BQ lookup entirely (useful in tests or when running offline). + + Returns: + A populated ``Manifest`` on success, or ``Manifest.empty()`` when no + prior state is available. Never raises — every error path falls + through with a logged warning. 
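+
+    Example (illustrative; the path is a placeholder and ``prefer_bigquery=False``
+    keeps the lookup fully local):
+
+        >>> manifest = load_previous_manifest(Path("output"), prefer_bigquery=False)
+        >>> isinstance(manifest, Manifest)
+        True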
+ """ + if prefer_bigquery: + # Imported here so the BigQuery client isn't constructed at module + # import time. Keeps tests that patch the source module fast and + # avoids touching gcloud creds when prefer_bigquery=False. + from a4d.gcp.bigquery import select_tracker_metadata + + df = select_tracker_metadata() + if df is not None and df.height > 0: + logger.info(f"Loaded manifest from BigQuery ({df.height} rows)") + return _build_manifest(df) + + local_parquet = output_root / "tables" / "tracker_metadata.parquet" + if local_parquet.exists(): + try: + df = pl.read_parquet(local_parquet) + logger.info( + f"Loaded manifest from local parquet ({df.height} rows): {local_parquet}" + ) + return _build_manifest(df) + except Exception as e: + logger.warning(f"Failed to read local manifest {local_parquet}: {e}") + + logger.info("No previous manifest available; treating every tracker as new") + return Manifest.empty() diff --git a/src/a4d/tables/__init__.py b/src/a4d/tables/__init__.py index 434cbbb..af9d4f9 100644 --- a/src/a4d/tables/__init__.py +++ b/src/a4d/tables/__init__.py @@ -1,18 +1 @@ """Table creation module for final output tables.""" - -from a4d.tables.logs import create_table_logs, parse_log_file -from a4d.tables.patient import ( - create_table_patient_data_annual, - create_table_patient_data_monthly, - create_table_patient_data_static, - read_cleaned_patient_data, -) - -__all__ = [ - "create_table_patient_data_annual", - "create_table_patient_data_monthly", - "create_table_patient_data_static", - "read_cleaned_patient_data", - "create_table_logs", - "parse_log_file", -] diff --git a/src/a4d/tables/logs.py b/src/a4d/tables/logs.py index 692c1bc..e5de419 100644 --- a/src/a4d/tables/logs.py +++ b/src/a4d/tables/logs.py @@ -32,7 +32,7 @@ def parse_log_file(log_file: Path) -> pl.DataFrame: records = [] try: - with open(log_file, encoding="utf-8") as f: + with open(log_file) as f: for line_num, line in enumerate(f, 1): line = line.strip() @@ -118,8 +118,22 @@ def parse_log_file(log_file: Path) -> pl.DataFrame: if not records: return pl.DataFrame() - # Create DataFrame with proper types - df = pl.DataFrame(records) + # Create DataFrame with proper types. schema_overrides pins the dtype of + # nullable columns so a log file whose first ~100 rows happen to have + # error_code=None / no exception / no file_name doesn't infer pl.Null + # and fail when later rows append actual strings or when concat'd with + # log files where these columns *are* populated. + df = pl.DataFrame( + records, + schema_overrides={ + "error_code": pl.Utf8, + "file_name": pl.Utf8, + "tracker_year": pl.Int32, + "tracker_month": pl.Int32, + "exception_type": pl.Utf8, + "exception_value": pl.Utf8, + }, + ) # Cast categorical columns for efficiency df = df.with_columns( diff --git a/src/a4d/tables/metadata.py b/src/a4d/tables/metadata.py new file mode 100644 index 0000000..af958e0 --- /dev/null +++ b/src/a4d/tables/metadata.py @@ -0,0 +1,100 @@ +"""Tracker metadata table generator. + +Mirrors R's ``run_script_5_create_metadata_table.R``: for each ``.xlsx`` tracker +under ``data_root``, emits a row with an MD5 hash and presence flags for the +four per-tracker output subdirs (``patient_data_{raw,cleaned}`` and +``product_data_{raw,cleaned}``). + +The MD5 helper :func:`md5_file` is also imported by ``a4d.state.filter`` to +hash current trackers when filtering against the previous run's manifest. 
+""" + +from __future__ import annotations + +import hashlib +from datetime import datetime, timezone +from pathlib import Path + +import polars as pl +from loguru import logger + +# Order matches R's subdirs vector; BigQuery consumers may depend on it. +_SUBDIRS: tuple[str, ...] = ( + "patient_data_cleaned", + "patient_data_raw", + "product_data_cleaned", + "product_data_raw", +) + +_SCHEMA: dict[str, pl.DataType] = { + "file_name": pl.String, + "clinic_code": pl.String, + "md5": pl.String, + "patient_data_cleaned": pl.Boolean, + "patient_data_raw": pl.Boolean, + "product_data_cleaned": pl.Boolean, + "product_data_raw": pl.Boolean, + "complete": pl.Boolean, + "timestamp": pl.Datetime("us"), +} + + +def md5_file(path: Path, chunk_size: int = 65536) -> str: + """Stream-hash a file with MD5. ``usedforsecurity=False`` satisfies FIPS.""" + h = hashlib.md5(usedforsecurity=False) + with path.open("rb") as f: + for chunk in iter(lambda: f.read(chunk_size), b""): + h.update(chunk) + return h.hexdigest() + + +def create_table_tracker_metadata( + data_root: Path, + output_root: Path, +) -> Path: + """Emit ``tracker_metadata.parquet`` to ``output_root/tables/``. + + Args: + data_root: Directory holding tracker ``.xlsx`` files. Scanned recursively + so clinic-folder layouts are supported. + output_root: Pipeline output root. Per-tracker output presence is looked + up in its four subdirs (``_SUBDIRS``); the metadata parquet is + written to ``output_root/tables/``. + + Returns: + Path to the written parquet file. + """ + tracker_files = sorted(data_root.rglob("*.xlsx")) + logger.info(f"Building tracker metadata for {len(tracker_files)} tracker(s)") + + # Index each subdir once so per-tracker lookups are O(1) prefix checks. + subdir_names: dict[str, list[str]] = {} + for subdir in _SUBDIRS: + target = output_root / subdir + subdir_names[subdir] = ( + [p.name for p in target.iterdir() if p.is_file()] if target.exists() else [] + ) + + now = datetime.now(tz=timezone.utc).replace(tzinfo=None) + rows: list[dict] = [] + for tracker_path in tracker_files: + file_name = tracker_path.stem + row = { + "file_name": file_name, + "clinic_code": tracker_path.parent.name, + "md5": md5_file(tracker_path), + } + for subdir in _SUBDIRS: + row[subdir] = any(name.startswith(file_name) for name in subdir_names[subdir]) + row["complete"] = all(row[s] for s in _SUBDIRS) + row["timestamp"] = now + rows.append(row) + + df = pl.DataFrame(rows, schema=_SCHEMA) + + output_dir = output_root / "tables" + output_dir.mkdir(parents=True, exist_ok=True) + output_file = output_dir / "tracker_metadata.parquet" + df.write_parquet(output_file) + logger.info(f"Tracker metadata saved: {output_file} ({df.height} rows)") + return output_file diff --git a/src/a4d/tables/patient.py b/src/a4d/tables/patient.py index 1865a00..3697967 100644 --- a/src/a4d/tables/patient.py +++ b/src/a4d/tables/patient.py @@ -22,15 +22,15 @@ def read_cleaned_patient_data(cleaned_files: list[Path]) -> pl.DataFrame: return pl.concat(dfs, how="vertical") -def create_table_patient_data_static(cleaned_files: list[Path], output_dir: Path) -> Path: +def create_table_patient_data_static(patient_data: pl.DataFrame, output_dir: Path) -> Path: """Create static patient data table. - Reads all cleaned patient data and creates a single table with static columns - (data that doesn't change monthly). Groups by patient_id and takes the latest - available data (latest year and month). 
+ Selects static columns (data that doesn't change monthly) from the already-loaded + cleaned patient dataframe. Groups by patient_id and takes the latest available + data (latest year and month). Args: - cleaned_files: List of paths to cleaned parquet files + patient_data: Combined cleaned patient dataframe (from read_cleaned_patient_data) output_dir: Directory to save output parquet file Returns: @@ -60,8 +60,6 @@ def create_table_patient_data_static(cleaned_files: list[Path], output_dir: Path "tracker_year", ] - patient_data = read_cleaned_patient_data(cleaned_files) - static_data = ( patient_data.select(static_columns) .sort(["patient_id", "tracker_year", "tracker_month"]) @@ -79,14 +77,14 @@ def create_table_patient_data_static(cleaned_files: list[Path], output_dir: Path return output_file -def create_table_patient_data_monthly(cleaned_files: list[Path], output_dir: Path) -> Path: +def create_table_patient_data_monthly(patient_data: pl.DataFrame, output_dir: Path) -> Path: """Create monthly patient data table. - Reads all cleaned patient data and creates a single table with dynamic columns - (data that changes monthly). Keeps all monthly records. + Selects dynamic monthly columns from the already-loaded cleaned patient + dataframe. Keeps all monthly records. Args: - cleaned_files: List of paths to cleaned parquet files + patient_data: Combined cleaned patient dataframe (from read_cleaned_patient_data) output_dir: Directory to save output parquet file Returns: @@ -127,8 +125,6 @@ def create_table_patient_data_monthly(cleaned_files: list[Path], output_dir: Pat "weight", ] - patient_data = read_cleaned_patient_data(cleaned_files) - monthly_data = patient_data.select(monthly_columns).sort( ["tracker_year", "tracker_month", "patient_id"] ) @@ -142,15 +138,15 @@ def create_table_patient_data_monthly(cleaned_files: list[Path], output_dir: Pat return output_file -def create_table_patient_data_annual(cleaned_files: list[Path], output_dir: Path) -> Path: +def create_table_patient_data_annual(patient_data: pl.DataFrame, output_dir: Path) -> Path: """Create annual patient data table. - Reads all cleaned patient data and creates a single table with annual columns - (data collected once per year). Groups by patient_id and tracker_year, taking - the latest month for each year. Only includes data from 2024 onwards. + Selects annual columns (data collected once per year) from the already-loaded + cleaned patient dataframe. Groups by patient_id and tracker_year, taking the + latest month for each year. Only includes data from 2024 onwards. Args: - cleaned_files: List of paths to cleaned parquet files + patient_data: Combined cleaned patient dataframe (from read_cleaned_patient_data) output_dir: Directory to save output parquet file Returns: @@ -193,8 +189,6 @@ def create_table_patient_data_annual(cleaned_files: list[Path], output_dir: Path "tracker_year", ] - patient_data = read_cleaned_patient_data(cleaned_files) - annual_data = ( patient_data.select(annual_columns) .filter(pl.col("tracker_year") >= 2024) diff --git a/src/a4d/tables/product.py b/src/a4d/tables/product.py new file mode 100644 index 0000000..6b48a4b --- /dev/null +++ b/src/a4d/tables/product.py @@ -0,0 +1,183 @@ +"""Product table creation. + +Covers R Script 3 (steps 3.1-3.3): merge all cleaned product parquets into +a single ``product_data`` table with the canonical 19-column schema. 
+""" + +from pathlib import Path + +import polars as pl +from loguru import logger + +from a4d.clean.converters import safe_convert_column +from a4d.clean.schema_product import apply_schema, get_product_data_schema +from a4d.clean.validators import fix_patient_id +from a4d.errors import ErrorCollector + + +def read_cleaned_product_data(cleaned_files: list[Path]) -> pl.DataFrame: + """Step 3.1 — read and diagonally concatenate cleaned product parquets. + + Uses ``pl.concat(..., how="diagonal")`` so trackers with missing columns + are filled with nulls rather than dropped. + + Args: + cleaned_files: List of cleaned product parquet paths. + + Returns: + Single DataFrame with the union of all tracker columns. + """ + if not cleaned_files: + raise ValueError("No cleaned product files provided") + + dfs = [pl.read_parquet(file) for file in cleaned_files] + return pl.concat(dfs, how="diagonal") + + +def create_table_product_data(cleaned_files: list[Path], output_dir: Path) -> Path: + """Build the final ``product_data`` table from cleaned tracker parquets. + + Covers R steps 3.1-3.3: + 1. Concatenate every cleaned product parquet. + 2. Preserve ``product_released_to`` as ``orig_product_released_to`` and + normalise ``product_released_to`` via ``fix_patient_id``. + 3. Apply the 19-column product schema with type enforcement; values that + fail conversion are replaced by the error sentinels from ``settings``. + + Args: + cleaned_files: Cleaned product parquet files from Script 2. + output_dir: Directory where ``product_data.parquet`` is written. + + Returns: + Path to the written ``product_data.parquet`` file. + """ + df = read_cleaned_product_data(cleaned_files) + + df = df.with_columns( + pl.col("product_released_to").alias("orig_product_released_to") + ) + + error_collector = ErrorCollector() + df = fix_patient_id( + df=df, + error_collector=error_collector, + patient_id_col="product_released_to", + ) + + # Adds any missing schema columns as typed nulls so the cast loop below + # has every target column available; safe_convert_column preserves order. + df = apply_schema(df) + + schema = get_product_data_schema() + for col, dtype in schema.items(): + if dtype in (pl.Int32, pl.Int64, pl.Float32, pl.Float64, pl.Date): + df = safe_convert_column( + df=df, + column=col, + target_type=dtype, + error_collector=error_collector, + patient_id_col="product_released_to", + ) + + # Surface any errors collected by fix_patient_id / safe_convert_column at + # the table-aggregation stage. These run outside any per-tracker file_logger, + # so without this loop the errors disappear silently. + for err in error_collector.errors: + logger.bind( + error_code=err.error_code, + script="script3", + function_name=err.function_name, + file_name=err.file_name, + column=err.column, + ).warning( + f"Table-stage error: {err.error_message} " + f"(file={err.file_name}, column={err.column}, value={err.original_value})" + ) + if error_collector.errors: + logger.info(f"Table-stage errors: {len(error_collector.errors)}") + + logger.info(f"Product data table dimensions: {df.shape}") + + output_dir.mkdir(parents=True, exist_ok=True) + output_path = output_dir / "product_data.parquet" + df.write_parquet(output_path) + + return output_path + + +def link_product_patient( + product_df: pl.DataFrame, + patient_static_path: Path, +) -> int: + """Validate product_released_to against patient_data_static. 
+ + LEFT-joins product rows onto the patient table on + ``(file_name, product_released_to ↔ patient_id)`` and logs each + ``(file_name, product_released_to)`` pair that has no patient match. + Logging-only — does not modify either table and never raises. + + Mirrors R's ``script3_link_product_patient.R`` with one deviation: + rows where ``product_released_to`` is null or equals the + ``error_val_character`` sentinel ("Undefined") are filtered out + before joining. Those are not real patient IDs and would flood the + log with non-issues. + + Args: + product_df: Product table (post-``fix_patient_id``). + patient_static_path: Path to ``patient_data_static.parquet``. + + Returns: + Total count of mismatched product rows for telemetry/test use. + """ + from a4d.config import settings + + if not patient_static_path.exists(): + logger.warning( + f"Patient table not available at {patient_static_path}; " + "skipping product-patient link validation." + ) + return 0 + + patient_keys = ( + pl.read_parquet(patient_static_path) + .select(["file_name", "patient_id"]) + .with_columns(pl.lit(True).alias("_patient_matched")) + ) + + sentinel = settings.error_val_character + candidates = product_df.filter( + pl.col("product_released_to").is_not_null() + & (pl.col("product_released_to") != sentinel) + ) + + joined = candidates.join( + patient_keys, + left_on=["file_name", "product_released_to"], + right_on=["file_name", "patient_id"], + how="left", + ) + mismatches = joined.filter(pl.col("_patient_matched").is_null()) + + mismatch_groups = ( + mismatches.group_by(["file_name", "product_released_to"]) + .agg(pl.len().alias("count")) + .sort("count", descending=True) + ) + + total_mismatched_rows = mismatches.height + distinct_pairs = mismatch_groups.height + total_examined = candidates.height + + for row in mismatch_groups.iter_rows(named=True): + logger.warning( + f"Unmatched product_released_to: file_name='{row['file_name']}' " + f"product_id='{row['product_released_to']}' count={row['count']}" + ) + + logger.info( + f"Product-patient link validation: {total_mismatched_rows} mismatched rows, " + f"{distinct_pairs} distinct (file × id) pairs, " + f"{total_examined} candidate product rows examined." + ) + + return total_mismatched_rows diff --git a/src/a4d/validate/__init__.py b/src/a4d/validate/__init__.py new file mode 100644 index 0000000..c731706 --- /dev/null +++ b/src/a4d/validate/__init__.py @@ -0,0 +1,6 @@ +"""Post-pipeline validation that compares cleaned output against raw source. + +Used by ``scripts/validate_source_vs_output.py``. The validator does not import +or modify either golden cleaning module (clean/patient.py, clean/product.py); +it replicates only the small slices of logic it needs. +""" diff --git a/src/a4d/validate/common.py b/src/a4d/validate/common.py new file mode 100644 index 0000000..e215d62 --- /dev/null +++ b/src/a4d/validate/common.py @@ -0,0 +1,114 @@ +"""Shared helpers for the source-vs-output validators.""" + +from __future__ import annotations + +from typing import Any + +import polars as pl + +from a4d.clean.converters import safe_convert_column +from a4d.config import settings +from a4d.errors import ErrorCollector + +# Mirrors clean/patient.py:241-256. Hyphen->underscore happens first, then we +# extract the leading "LETTERS_NON-UNDERSCORE-CHARS" group. Single-token IDs +# (no underscore) pass through unchanged. 
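+# Illustrative walk-through with a made-up ID: "KH-AB001-02" -> hyphens to
+# underscores -> "KH_AB001_02" -> leading-group extract -> "KH_AB001".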
+_NORMALIZE_REGEX = r"^([A-Z]+_[^_]+)" + + +def normalize_patient_id(col: pl.Expr) -> pl.Expr: + """Replicate the patient_id normalization in clean/patient.py:241-256. + + Returns an expression — caller wraps it in ``with_columns``. + """ + hyphen_to_underscore = col.str.replace_all("-", "_") + return ( + pl.when(hyphen_to_underscore.str.contains("_")) + .then(hyphen_to_underscore.str.extract(_NORMALIZE_REGEX, 1)) + .otherwise(hyphen_to_underscore) + ) + + +def safe_parse_series( + raw_col: pl.Series, + target_type: pl.DataType | type[pl.DataType], +) -> pl.Series: + """Re-parse a raw string Series to ``target_type``, discarding parse errors. + + ``safe_convert_column`` writes to its caller's ErrorCollector with code + ``type_conversion``. We pass a sacrificial collector here so those entries + don't pollute the validator's findings collector. + """ + name = raw_col.name + sacrificial = ErrorCollector() + df = pl.DataFrame({name: raw_col}) + parsed = safe_convert_column( + df=df, + column=name, + target_type=target_type, + error_collector=sacrificial, + # safe_convert_column reads file_name/patient_id for error attribution; we + # don't care about the entries it generates, but the columns must exist. + file_name_col="__unused_file__", + patient_id_col="__unused_pid__", + ) + series = parsed[name] + # safe_convert_column writes settings.error_val_numeric / error_val_character / + # error_val_date for parse failures; map those sentinels back to null so the + # validator can cleanly distinguish "parsed" from "failed to parse". + if series.dtype.is_numeric(): + return series.replace({settings.error_val_numeric: None}) + if series.dtype in (pl.Utf8, pl.String): + return series.replace({settings.error_val_character: None}) + if series.dtype == pl.Date: + sentinel_date = pl.Series("_s", [settings.error_val_date]).str.to_date()[0] + return series.replace({sentinel_date: None}) + return series + + +def emit_finding( + collector: ErrorCollector, + *, + file_name: str, + patient_id: str, + column: str, + original_value: Any, + error_message: str, + error_code: str, + function_name: str, +) -> None: + """Thin wrapper around ``ErrorCollector.add_error`` enforcing the schema. + + The collector accepts only the ErrorCode literal types — validator codes + that don't fit (MISSING_ROW, VALUE_SHIFT, ...) are encoded in + ``error_message`` and the underlying ``error_code`` is set to the closest + existing literal. Caller must pass one of: ``"missing_value"`` for + missing/phantom rows, ``"invalid_value"`` for shifts/range violations, + ``"type_conversion"`` for parse-driven nulls. + """ + collector.add_error( + file_name=file_name, + patient_id=patient_id, + column=column, + original_value="" if original_value is None else str(original_value), + error_message=error_message, + error_code=error_code, # type: ignore[arg-type] + function_name=function_name, + script="validate", + ) + + +def is_close(a: float | None, b: float | None, *, abs_tol: float = 1e-6, rel_tol: float = 1e-4) -> bool: + """Null-aware float comparison with abs and relative tolerance. + + Both null -> equal. Either-side null -> not equal. 
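+
+    With the defaults: is_close(1.0, 1.0 + 1e-7) is True (inside abs_tol),
+    is_close(1e6, 1e6 + 50) is True (inside rel_tol), and is_close(None, 0.0) is False.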
+ """ + if a is None and b is None: + return True + if a is None or b is None: + return False + diff = abs(a - b) + if diff <= abs_tol: + return True + scale = max(abs(a), abs(b)) + return diff <= rel_tol * scale diff --git a/src/a4d/validate/source_vs_output_patient.py b/src/a4d/validate/source_vs_output_patient.py new file mode 100644 index 0000000..4c31d5e --- /dev/null +++ b/src/a4d/validate/source_vs_output_patient.py @@ -0,0 +1,296 @@ +"""Source-vs-output validation for the patient pipeline. + +See ``C:/Users/furin/.claude/plans/output-vs-source-validation-swirling-nygaard.md`` +for the design rationale and the catalog of checks. +""" + +from __future__ import annotations + +from pathlib import Path + +import polars as pl +from loguru import logger + +from a4d.clean.validators import load_numeric_ranges +from a4d.errors import ErrorCollector +from a4d.validate.common import ( + emit_finding, + is_close, + normalize_patient_id, + safe_parse_series, +) + +# Columns whose values the cleaner re-derives or transforms. Excluded from +# VALUE_SHIFT comparisons. +DERIVED_COLUMNS = frozenset({"age", "bmi", "t1d_diagnosis_age"}) + +# Columns the cleaner converts (height: cm->m; FBG: mmol<->mg). Excluded from +# VALUE_SHIFT in v1; promoted to a follow-up. +UNIT_CONVERTED_COLUMNS = frozenset({"height", "fbg_updated_mg", "fbg_updated_mmol"}) + +# Join key for row-level checks. Mirrors patient_data_monthly's natural grain. +ROW_KEY = ["clinic_id", "patient_id", "tracker_year", "tracker_month"] + + +def _load_raw(run_dir: Path) -> pl.DataFrame | None: + """Concatenate all per-tracker raw patient parquets in the run directory.""" + raw_dir = run_dir / "patient_data_raw" + if not raw_dir.exists(): + return None + files = sorted(raw_dir.glob("*_patient_raw.parquet")) + if not files: + return None + frames = [pl.read_parquet(f) for f in files] + return pl.concat(frames, how="diagonal") + + +def _load_cleaned_monthly(run_dir: Path) -> pl.DataFrame | None: + path = run_dir / "tables" / "patient_data_monthly.parquet" + if not path.exists(): + return None + return pl.read_parquet(path) + + +def _normalize_join_keys(df: pl.DataFrame) -> pl.DataFrame: + """Apply patient_id normalization and cast year/month for clean joins.""" + out = df.with_columns(normalize_patient_id(pl.col("patient_id")).alias("patient_id")) + return out.with_columns( + [ + pl.col("tracker_year").cast(pl.Int64, strict=False), + pl.col("tracker_month").cast(pl.Int64, strict=False), + ] + ) + + +def check_missing_patients( + raw: pl.DataFrame, cleaned: pl.DataFrame, collector: ErrorCollector +) -> None: + """MISSING_ROW + PHANTOM_ROW: anti-joins on the normalized row key.""" + raw_keys = _normalize_join_keys(raw).select(ROW_KEY).unique() + cleaned_keys = _normalize_join_keys(cleaned).select(ROW_KEY).unique() + + missing = raw_keys.join(cleaned_keys, on=ROW_KEY, how="anti") + for row in missing.iter_rows(named=True): + emit_finding( + collector, + file_name="", + patient_id=row["patient_id"] or "unknown", + column="__row__", + original_value=f"{row['clinic_id']}|{row['tracker_year']}|{row['tracker_month']}", + error_message=( + f"MISSING_ROW: raw row for {row['patient_id']} " + f"({row['tracker_year']}-{row['tracker_month']:02d}) " + f"not found in cleaned patient_data_monthly" + ), + error_code="missing_value", + function_name="check_missing_patients", + ) + + phantom = cleaned_keys.join(raw_keys, on=ROW_KEY, how="anti") + for row in phantom.iter_rows(named=True): + emit_finding( + collector, + file_name="", + patient_id=row["patient_id"] 
or "unknown", + column="__row__", + original_value=f"{row['clinic_id']}|{row['tracker_year']}|{row['tracker_month']}", + error_message=( + f"PHANTOM_ROW: cleaned row for {row['patient_id']} " + f"({row['tracker_year']}-{row['tracker_month']:02d}) " + f"has no matching raw source" + ), + error_code="missing_value", + function_name="check_missing_patients", + ) + + +def _join_for_cell_checks(raw: pl.DataFrame, cleaned: pl.DataFrame) -> pl.DataFrame: + """Inner join raw and cleaned on the normalized row key. + + Suffixes raw columns with ``_raw`` and cleaned columns with ``_clean``; + the join key columns are deduplicated under their original names. + """ + raw_n = _normalize_join_keys(raw) + cleaned_n = _normalize_join_keys(cleaned) + + raw_renamed = raw_n.rename( + {c: f"{c}_raw" for c in raw_n.columns if c not in ROW_KEY} + ) + cleaned_renamed = cleaned_n.rename( + {c: f"{c}_clean" for c in cleaned_n.columns if c not in ROW_KEY} + ) + return raw_renamed.join(cleaned_renamed, on=ROW_KEY, how="inner") + + +def check_unexpected_nulls( + joined: pl.DataFrame, collector: ErrorCollector, file_name_col: str = "file_name_raw" +) -> None: + """UNEXPECTED_NULL: raw cell non-empty, cleaned cell null. + + Each finding carries a ``was_parseable`` indicator in the message so + reviewers can filter to surprising cases (raw value parsed cleanly but the + cleaned cell is still null). + """ + cleaned_cols = [c for c in joined.columns if c.endswith("_clean")] + for clean_col in cleaned_cols: + base = clean_col[: -len("_clean")] + raw_col = f"{base}_raw" + if raw_col not in joined.columns: + continue + + target_dtype = joined.schema[clean_col] + # null in cleaned, non-empty (after strip) in raw + raw_str = pl.col(raw_col).cast(pl.Utf8).str.strip_chars() + suspect = joined.filter( + pl.col(clean_col).is_null() + & raw_str.is_not_null() + & (raw_str.str.len_chars() > 0) + ) + if suspect.is_empty(): + continue + + # Determine parseability for this column once via the sacrificial parser. + if target_dtype.is_numeric() or target_dtype == pl.Date or target_dtype == pl.Boolean: + parsed = safe_parse_series(suspect[raw_col].cast(pl.Utf8), target_dtype) + else: + parsed = suspect[raw_col].cast(pl.Utf8) + + for row, parsed_val in zip(suspect.iter_rows(named=True), parsed.to_list(), strict=True): + was_parseable = parsed_val is not None + emit_finding( + collector, + file_name=row.get(file_name_col) or "", + patient_id=row.get("patient_id") or "unknown", + column=base, + original_value=row[raw_col], + error_message=( + f"UNEXPECTED_NULL: raw={row[raw_col]!r} cleaned=null " + f"was_parseable={was_parseable}" + ), + error_code="missing_value", + function_name="check_unexpected_nulls", + ) + + +def check_value_shifts( + joined: pl.DataFrame, collector: ErrorCollector, file_name_col: str = "file_name_raw" +) -> None: + """VALUE_SHIFT: numeric, non-derived, non-unit-converted columns only. + + For each eligible column, re-parse raw via the sacrificial converter and + compare to cleaned with abs/rel tolerance. + """ + for clean_col in [c for c in joined.columns if c.endswith("_clean")]: + base = clean_col[: -len("_clean")] + if base in DERIVED_COLUMNS or base in UNIT_CONVERTED_COLUMNS: + continue + raw_col = f"{base}_raw" + if raw_col not in joined.columns: + continue + clean_dtype = joined.schema[clean_col] + if not clean_dtype.is_numeric(): + continue + + # Subset to rows where both sides have content; let the sacrificial + # parser handle the empty/whitespace -> null normalization. 
+ subset = joined.select([raw_col, clean_col, "patient_id", file_name_col]) + parsed_raw = safe_parse_series(subset[raw_col].cast(pl.Utf8), clean_dtype) + + cleaned_vals = subset[clean_col].to_list() + raw_vals = parsed_raw.to_list() + pids = subset["patient_id"].to_list() + files = subset[file_name_col].to_list() + + for raw_v, clean_v, pid, fname in zip(raw_vals, cleaned_vals, pids, files, strict=True): + # Both null is fine; one-sided null is captured by other checks. + if raw_v is None or clean_v is None: + continue + if is_close(float(raw_v), float(clean_v)): + continue + emit_finding( + collector, + file_name=fname or "", + patient_id=pid or "unknown", + column=base, + original_value=raw_v, + error_message=f"VALUE_SHIFT: raw={raw_v} cleaned={clean_v}", + error_code="invalid_value", + function_name="check_value_shifts", + ) + + +def _apply_height_auto_conversion(s: pl.Series) -> pl.Series: + """Mirror clean/patient.py:530-535: divide values >2.3 by 100.""" + return pl.Series( + s.name, + [None if v is None else (v / 100.0 if v > 2.3 else v) for v in s.to_list()], + dtype=pl.Float64, + ) + + +def check_out_of_range( + raw: pl.DataFrame, collector: ErrorCollector, file_name_col: str = "file_name" +) -> None: + """OUT_OF_RANGE_RAW: raw values outside the YAML numeric_ranges. + + Runs against the raw frame because the cleaner replaces out-of-range + values with the sentinel before they reach the cleaned table. ``height`` + has the cm->m auto-conversion applied first to mirror the cleaner. + """ + ranges = load_numeric_ranges() + raw_n = _normalize_join_keys(raw) + for column, bounds in ranges.items(): + if column not in raw_n.columns: + continue + parsed = safe_parse_series(raw_n[column].cast(pl.Utf8), pl.Float64) + if column == "height": + parsed = _apply_height_auto_conversion(parsed) + + min_v = float(bounds["min"]) + max_v = float(bounds["max"]) + + pids = raw_n["patient_id"].to_list() + files = raw_n[file_name_col].to_list() if file_name_col in raw_n.columns else [""] * raw_n.height + for v, pid, fname in zip(parsed.to_list(), pids, files, strict=True): + if v is None: + continue + if v < min_v or v > max_v: + emit_finding( + collector, + file_name=fname or "", + patient_id=pid or "unknown", + column=column, + original_value=v, + error_message=( + f"OUT_OF_RANGE_RAW: value {v} outside [{min_v}, {max_v}]" + ), + error_code="invalid_value", + function_name="check_out_of_range", + ) + + +def validate_patient_run(run_dir: Path) -> ErrorCollector | None: + """Run all four patient checks against a pipeline run directory. + + Returns ``None`` if the run does not contain patient pipeline outputs + (e.g. V5 is product-only). 
+ """ + raw = _load_raw(run_dir) + cleaned = _load_cleaned_monthly(run_dir) + if raw is None or cleaned is None: + logger.info(f"Patient pipeline outputs not found in {run_dir}; skipping.") + return None + + collector = ErrorCollector() + logger.info( + f"Patient validation: raw={raw.shape}, cleaned_monthly={cleaned.shape}" + ) + + check_missing_patients(raw, cleaned, collector) + joined = _join_for_cell_checks(raw, cleaned) + check_unexpected_nulls(joined, collector) + check_value_shifts(joined, collector) + check_out_of_range(raw, collector) + + logger.info(f"Patient validation: {len(collector)} findings") + return collector diff --git a/src/a4d/validate/source_vs_output_product.py b/src/a4d/validate/source_vs_output_product.py new file mode 100644 index 0000000..2dbfcb0 --- /dev/null +++ b/src/a4d/validate/source_vs_output_product.py @@ -0,0 +1,202 @@ +"""Source-vs-output validation for the product pipeline. + +Group-granularity checks only in v1: per-cell checks would require reproducing +cleaning steps 2.0-2.5 to reconstruct the per-(sheet, product) row index, which +we deliberately avoid (see the plan at +``C:/Users/furin/.claude/plans/output-vs-source-validation-swirling-nygaard.md``). +""" + +from __future__ import annotations + +from pathlib import Path + +import polars as pl +from loguru import logger + +from a4d.errors import ErrorCollector +from a4d.validate.common import emit_finding + +GROUP_KEY = ["file_name", "product_sheet_name", "product"] + +# Threshold for COLUMN_NULL_RATE_DELTA: emit a finding if cleaned null rate +# exceeds raw null rate by more than this fraction (5 percentage points). +NULL_RATE_DELTA_THRESHOLD = 0.05 + + +def _explode_multi_product_cells(df: pl.DataFrame) -> pl.DataFrame: + """Mirror clean/product.py:169-242 (the row-changing part only). + + Splits ``product`` on '; ' and ' and ', then ``.explode()``. We deliberately + skip the parenthetical-units extraction because it doesn't change row count + or the join keys. 
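+
+    E.g. a raw cell "Lantus and Syringes" (product names illustrative) becomes
+    "Lantus; Syringes" and then two rows, one per product, with every other
+    column value repeated.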
+ """ + if "product" not in df.columns: + return df + df = df.with_columns( + pl.col("product").cast(pl.Utf8).str.replace_all(" and ", "; ").alias("product") + ) + return df.with_columns(pl.col("product").str.split("; ")).explode("product") + + +def _load_raw(run_dir: Path) -> pl.DataFrame | None: + raw_dir = run_dir / "product_data_raw" + if not raw_dir.exists(): + return None + files = sorted(raw_dir.glob("*_product_raw.parquet")) + if not files: + return None + frames = [pl.read_parquet(f) for f in files] + return pl.concat(frames, how="diagonal") + + +def _load_cleaned(run_dir: Path) -> pl.DataFrame | None: + path = run_dir / "tables" / "product_data.parquet" + if not path.exists(): + return None + return pl.read_parquet(path) + + +def check_missing_groups( + raw_exploded: pl.DataFrame, cleaned: pl.DataFrame, collector: ErrorCollector +) -> None: + """MISSING_GROUP + PHANTOM_GROUP: anti-joins on (file, sheet, product).""" + raw_keys = ( + raw_exploded.filter(pl.col("product").is_not_null()) + .select(GROUP_KEY) + .unique() + ) + cleaned_keys = ( + cleaned.filter(pl.col("product").is_not_null()).select(GROUP_KEY).unique() + ) + + missing = raw_keys.join(cleaned_keys, on=GROUP_KEY, how="anti") + for row in missing.iter_rows(named=True): + emit_finding( + collector, + file_name=row["file_name"] or "", + patient_id="", + column="__group__", + original_value=f"{row['product_sheet_name']}|{row['product']}", + error_message=( + f"MISSING_GROUP: raw (sheet={row['product_sheet_name']}, " + f"product={row['product']!r}) absent in cleaned product_data" + ), + error_code="missing_value", + function_name="check_missing_groups", + ) + + phantom = cleaned_keys.join(raw_keys, on=GROUP_KEY, how="anti") + for row in phantom.iter_rows(named=True): + emit_finding( + collector, + file_name=row["file_name"] or "", + patient_id="", + column="__group__", + original_value=f"{row['product_sheet_name']}|{row['product']}", + error_message=( + f"PHANTOM_GROUP: cleaned (sheet={row['product_sheet_name']}, " + f"product={row['product']!r}) has no matching raw source" + ), + error_code="missing_value", + function_name="check_missing_groups", + ) + + +def check_row_count_delta( + raw_exploded: pl.DataFrame, cleaned: pl.DataFrame, collector: ErrorCollector +) -> None: + """ROW_COUNT_DELTA: per-group row counts raw vs cleaned. + + Reported as a magnitude with no judgment about right/wrong — the cleaner + legitimately drops uninformative rows. Reviewers triage. 
+ """ + raw_counts = ( + raw_exploded.filter(pl.col("product").is_not_null()) + .group_by(GROUP_KEY) + .len() + .rename({"len": "raw_count"}) + ) + cleaned_counts = ( + cleaned.filter(pl.col("product").is_not_null()) + .group_by(GROUP_KEY) + .len() + .rename({"len": "cleaned_count"}) + ) + joined = raw_counts.join(cleaned_counts, on=GROUP_KEY, how="inner") + deltas = joined.filter(pl.col("raw_count") != pl.col("cleaned_count")) + + for row in deltas.iter_rows(named=True): + delta = row["cleaned_count"] - row["raw_count"] + emit_finding( + collector, + file_name=row["file_name"] or "", + patient_id="", + column="__group__", + original_value=f"raw={row['raw_count']}|cleaned={row['cleaned_count']}", + error_message=( + f"ROW_COUNT_DELTA: (sheet={row['product_sheet_name']}, " + f"product={row['product']!r}) raw={row['raw_count']} " + f"cleaned={row['cleaned_count']} delta={delta:+d}" + ), + error_code="invalid_value", + function_name="check_row_count_delta", + ) + + +def check_column_null_rate_delta( + raw: pl.DataFrame, cleaned: pl.DataFrame, collector: ErrorCollector +) -> None: + """COLUMN_NULL_RATE_DELTA: per-column null rate raw vs cleaned across the run. + + Emit one finding per column where cleaned null rate exceeds raw null rate + by >NULL_RATE_DELTA_THRESHOLD. Coarse, but surfaces silent column-wide losses. + """ + raw_total = raw.height + cleaned_total = cleaned.height + if raw_total == 0 or cleaned_total == 0: + return + + common = [c for c in cleaned.columns if c in raw.columns] + for col in common: + raw_null_rate = raw[col].null_count() / raw_total + cleaned_null_rate = cleaned[col].null_count() / cleaned_total + delta = cleaned_null_rate - raw_null_rate + if delta > NULL_RATE_DELTA_THRESHOLD: + emit_finding( + collector, + file_name="", + patient_id="", + column=col, + original_value=f"raw_null_rate={raw_null_rate:.3f}", + error_message=( + f"COLUMN_NULL_RATE_DELTA: column={col!r} " + f"raw_null_rate={raw_null_rate:.3f} " + f"cleaned_null_rate={cleaned_null_rate:.3f} " + f"delta={delta:+.3f}" + ), + error_code="invalid_value", + function_name="check_column_null_rate_delta", + ) + + +def validate_product_run(run_dir: Path) -> ErrorCollector | None: + """Run the four product checks. 
Returns None if no product output exists.""" + raw = _load_raw(run_dir) + cleaned = _load_cleaned(run_dir) + if raw is None or cleaned is None: + logger.info(f"Product pipeline outputs not found in {run_dir}; skipping.") + return None + + collector = ErrorCollector() + raw_exploded = _explode_multi_product_cells(raw) + logger.info( + f"Product validation: raw={raw.shape} (exploded={raw_exploded.shape}), " + f"cleaned={cleaned.shape}" + ) + + check_missing_groups(raw_exploded, cleaned, collector) + check_row_count_delta(raw_exploded, cleaned, collector) + check_column_null_rate_delta(raw, cleaned, collector) + + logger.info(f"Product validation: {len(collector)} findings") + return collector diff --git a/tests/test_clean/test_converters.py b/tests/test_clean/test_converters.py index ab48665..caf6978 100644 --- a/tests/test_clean/test_converters.py +++ b/tests/test_clean/test_converters.py @@ -1,13 +1,17 @@ """Tests for type conversion with error tracking.""" +from datetime import date + import polars as pl from a4d.clean.converters import ( correct_decimal_sign, cut_numeric_value, + parse_date_column, safe_convert_column, safe_convert_multiple_columns, ) +from a4d.clean.date_parser import rescue_date_typos from a4d.config import settings from a4d.errors import ErrorCollector @@ -335,3 +339,80 @@ def test_cut_numeric_value_ignores_existing_errors(): # Only 30 should be logged, not the existing error value assert result["age"].to_list() == [15, settings.error_val_numeric, settings.error_val_numeric] assert len(collector) == 1 + + +def test_rescue_date_typos_known_patterns(): + assert rescue_date_typos("23-Mach-20") == ("23-MAR-20", True) + assert rescue_date_typos("15-N0v-2021") == ("15-NOV-2021", True) + assert rescue_date_typos("10-0ct-2024") == ("10-OCT-2024", True) + assert rescue_date_typos("01-N0vember-2021") == ("01-NOVEMBER-2021", True) + + +def test_rescue_date_typos_passthrough(): + assert rescue_date_typos("15-Mar-2024") == ("15-Mar-2024", False) + # Word-boundary protects unrelated substrings. + assert rescue_date_typos("CON0CTOR") == ("CON0CTOR", False) + + +def test_parse_date_column_rescues_typo_and_logs(): + df = pl.DataFrame( + { + "file_name": ["t.xlsx", "t.xlsx"], + "patient_id": ["P1", "P2"], + "entry_date": ["23-Mach-20", "15-Mar-2024"], + } + ) + collector = ErrorCollector() + + result = parse_date_column(df, "entry_date", collector) + + parsed = result["entry_date"].to_list() + assert parsed[0] == date(2020, 3, 23) + assert parsed[1] == date(2024, 3, 15) + assert len(collector) == 1 + err = collector.errors[0] + assert err.error_code == "typo_rescued" + assert err.column == "entry_date" + assert err.original_value == "23-Mach-20" + assert err.patient_id == "P1" + + +def test_parse_date_column_logs_unparseable_dates(): + """Pin parse_date_column's existing observability for genuinely unparseable + cells. R has a separate 'non_processed_dates' warning that fires on rows + R cannot parse via its narrower harmoniser; Python's parse_date_flexible + parses many of those rows successfully (e.g. "Mar 18" abbreviated formats) + and only sentinels truly-unparseable ones — at which point this existing + type_conversion log entry covers the equivalent signal with strictly + better signal-to-noise. See cleaning_divergences.md §10. 
+ """ + df = pl.DataFrame( + { + "file_name": ["test.xlsx", "test.xlsx", "test.xlsx"], + "patient_id": ["P1", "P2", "P3"], + "entry_date": ["2024-03-15", "garbage_value_xyz", "2024-04-20"], + }, + schema={ + "file_name": pl.String, + "patient_id": pl.String, + "entry_date": pl.String, + }, + ) + collector = ErrorCollector() + + result = parse_date_column( + df=df, column="entry_date", error_collector=collector + ) + + parsed = result["entry_date"].to_list() + assert parsed[0] == date(2024, 3, 15) + assert parsed[1] == date(9999, 9, 9) # error_val_date sentinel + assert parsed[2] == date(2024, 4, 20) + + assert len(collector) == 1 + err = collector.errors[0] + assert err.error_code == "type_conversion" + assert err.function_name == "parse_date_column" + assert err.column == "entry_date" + assert err.original_value == "garbage_value_xyz" + assert err.patient_id == "P2" diff --git a/tests/test_clean/test_product.py b/tests/test_clean/test_product.py new file mode 100644 index 0000000..73e2a30 --- /dev/null +++ b/tests/test_clean/test_product.py @@ -0,0 +1,770 @@ +"""Tests for product cleaning helpers.""" + +from datetime import date + +import polars as pl + +from a4d.clean.product import ( + _check_entry_dates_match_sheet, + _clean_received_from, + _compute_running_balance, + _extract_balance_from_received, + _fill_product_names_and_sort, + _format_dates, + _split_multi_product_cells, + _switch_misplaced_columns, + _validate_entry_dates, + clean_product_data, +) +from a4d.errors import ErrorCollector + + +def _entry_date_df( + products: list[str], + entry_dates: list[date | None], + table_year: int = 2024, +) -> pl.DataFrame: + n = len(products) + return pl.DataFrame( + { + "product": products, + "product_entry_date": entry_dates, + "product_table_year": [table_year] * n, + "product_sheet_name": ["Jun"] * n, + "file_name": ["t.xlsx"] * n, + }, + schema={ + "product": pl.String, + "product_entry_date": pl.Date, + "product_table_year": pl.Int32, + "product_sheet_name": pl.String, + "file_name": pl.String, + }, + ) + + +def test_validate_entry_dates_flags_future_dates_within_window(): + df = _entry_date_df( + products=["P1", "P2", "P3"], + entry_dates=[date(2024, 6, 1), date(2099, 3, 15), None], + ) + collector = ErrorCollector() + + result = _validate_entry_dates(df, collector) + + parsed = result["product_entry_date"].to_list() + assert parsed[0] == date(2024, 6, 1) + assert parsed[1] == date(9999, 9, 9) + assert parsed[2] is None + assert len(collector) == 1 + err = collector.errors[0] + assert err.column == "product_entry_date" + assert err.error_code == "invalid_value" + assert err.patient_id == "P2" + + +def test_validate_entry_dates_preserves_buddhist_era_dates(): + df = _entry_date_df( + products=["P1", "P2"], + entry_dates=[date(2567, 11, 11), date(9999, 9, 9)], + ) + collector = ErrorCollector() + + result = _validate_entry_dates(df, collector) + + parsed = result["product_entry_date"].to_list() + assert parsed[0] == date(2567, 11, 11) + assert parsed[1] == date(9999, 9, 9) + assert len(collector) == 0 + + +def test_validate_entry_dates_missing_columns_is_noop(): + df = pl.DataFrame({"product": ["P1"]}) + collector = ErrorCollector() + + result = _validate_entry_dates(df, collector) + + assert result.equals(df) + assert len(collector) == 0 + + +def test_validate_entry_dates_logs_year_floor_but_preserves_date(): + """Year-floor violations (parsed.year < tracker_year - YEAR_FLOOR_DELTA) are + logged via error_collector but the parsed date is preserved in + product_entry_date. 
R does not validate; preserving the date here aligns + the downstream sort/cumsum trajectory with R while keeping the audit + trail.""" + df = _entry_date_df( + products=["P1", "P2", "P3"], + entry_dates=[date(2024, 6, 1), date(1967, 2, 5), None], + ) + collector = ErrorCollector() + + result = _validate_entry_dates(df, collector) + + parsed = result["product_entry_date"].to_list() + assert parsed[0] == date(2024, 6, 1) + assert parsed[1] == date(1967, 2, 5) + assert parsed[2] is None + assert len(collector) == 1 + err = collector.errors[0] + assert err.column == "product_entry_date" + assert err.error_code == "invalid_value" + assert err.patient_id == "P2" + assert "before" in err.error_message + + +def test_switch_misplaced_columns_scoped_per_sheet(): + """Swap fires only for sheets with 'Remaining Stock'; clean sheets in + the same file pass through untouched (regression for 2018_Penang_DC).""" + df = pl.DataFrame( + { + "product_sheet_name": ["Jul18", "Jul18", "Nov18", "Nov18"], + "product": ["A", "A", "A", "A"], + "product_units_received": ["70", None, "Remaining Stock", None], + "product_received_from": ["DKSH", None, "72", None], + }, + schema={ + "product_sheet_name": pl.String, + "product": pl.String, + "product_units_received": pl.String, + "product_received_from": pl.String, + }, + ) + + out = _switch_misplaced_columns(df) + + jul = out.filter(pl.col("product_sheet_name") == "Jul18") + assert jul["product_units_received"].to_list() == ["70", None] + assert jul["product_received_from"].to_list() == ["DKSH", None] + + nov = out.filter(pl.col("product_sheet_name") == "Nov18") + assert nov["product_units_received"].to_list() == ["72", None] + assert nov["product_received_from"].to_list() == ["Remaining Stock", None] + + +def test_extract_balance_from_received_scoped_per_sheet(): + """Rewrite fires only for sheets containing 'Balance' markers; clean + sheets in the same file pass through untouched (regression for 2018 + Preah Kossamak / 2019 Sultanah Bahiyah_DC / 2020 LFHC_DC, where Aug18 + donor names were blanked because Sep18+ used the Balance convention).""" + df = pl.DataFrame( + { + "product_sheet_name": ["Aug18", "Sep18", "Sep18"], + "product": ["A", "A", "A"], + "product_units_received": ["20", "Start Balance", None], + "product_units_released": [None, "2", None], + "product_received_from": ["GE100", None, None], + }, + schema={ + "product_sheet_name": pl.String, + "product": pl.String, + "product_units_received": pl.String, + "product_units_released": pl.String, + "product_received_from": pl.String, + }, + ) + + out = _extract_balance_from_received(df) + + aug = out.filter(pl.col("product_sheet_name") == "Aug18") + assert aug["product_received_from"].to_list() == ["GE100"] + assert aug["product_units_released"].to_list() == [None] + + sep = out.filter(pl.col("product_sheet_name") == "Sep18") + assert sep["product_received_from"].to_list() == ["2", None] + assert sep["product_units_released"].to_list() == [None, None] + + +def test_extract_balance_from_received_clears_released_on_all_balance_sheet(): + """Sheets where every row is a Balance marker still get units_released + cleared. 
Regression for 2019 PKH Oct19 / 2020 PKH May20 (18 rows): the + sheet_triggered predicate uses received_from.is_null().any().over(sheet), + which flips False after the first .with_columns() populates received_from + on every Balance row — so without caching, the second call's released- + clear pass becomes a no-op.""" + df = pl.DataFrame( + { + "product_sheet_name": ["Oct19", "Oct19"], + "product": ["GE100", "GE100"], + "product_units_received": ["Start Balance", "End Balance"], + "product_units_released": ["1", "1"], + "product_received_from": [None, None], + }, + schema={ + "product_sheet_name": pl.String, + "product": pl.String, + "product_units_received": pl.String, + "product_units_released": pl.String, + "product_received_from": pl.String, + }, + ) + + out = _extract_balance_from_received(df) + + assert out["product_received_from"].to_list() == ["1", "1"] + assert out["product_units_released"].to_list() == [None, None] + + +def test_extract_balance_from_received_preserves_supplier_on_triggered_sheet(): + """Non-Balance rows on a triggered sheet keep their supplier name. + Regression for the Mahosot 2020 DKSH stock-receipt rows: the + ``(?i)Balance`` trigger substring-matches "START BALANCE" / "END BALANCE", + firing on every standard sheet. Before the fix, the case_when's catch-all + nulled ``product_received_from`` on every non-Balance row in the triggered + sheet, blanking legitimate supplier names. After the fix the catch-all + preserves received_from.""" + df = pl.DataFrame( + { + "product_sheet_name": ["Mar20", "Mar20", "Mar20"], + "product": ["Performa", "Performa", "Performa"], + "product_units_received": ["START BALANCE", "100", "END BALANCE"], + "product_units_released": [None, None, None], + "product_received_from": ["43", "DKSH", "87"], + }, + schema={ + "product_sheet_name": pl.String, + "product": pl.String, + "product_units_received": pl.String, + "product_units_released": pl.String, + "product_received_from": pl.String, + }, + ) + + out = _extract_balance_from_received(df) + + # Balance-marker rows pass through unchanged (received_from non-null, + # so no released value to relocate; arm 2 is a no-op rewrite). + # Change row in the middle keeps DKSH instead of being nulled. + assert out["product_received_from"].to_list() == ["43", "DKSH", "87"] + # units_released stays null on all rows (was null going in). + assert out["product_units_released"].to_list() == [None, None, None] + + +def test_extract_balance_from_received_nulls_total_subtotal_label(): + """Typist subtotal-row labels ("Total" in product_received_from) are + nulled, while legitimate supplier names on the same triggered sheet + survive. Regression for the 47 corpus-wide subtotal rows in 2019 + Penang DC / VNCH / Mandalay etc. that the supplier-preservation fix + inadvertently kept; R nulled them implicitly via its no-default + case_when. "Total" is the label the typist puts on the end-of-product- + block subtotal row, never a supplier.""" + # Sheet must be "triggered": need at least one Balance marker AND at + # least one row with null received_from. Last row supplies the null. 
+ df = pl.DataFrame( + { + "product_sheet_name": ["Mar20", "Mar20", "Mar20", "Mar20", "Mar20"], + "product": ["Performa", "Performa", "Performa", "Performa", "Performa"], + "product_units_received": ["START BALANCE", "100", "0", "END BALANCE", "50"], + "product_units_released": [None, None, None, None, None], + "product_received_from": ["43", "DKSH", "Total", "87", None], + }, + schema={ + "product_sheet_name": pl.String, + "product": pl.String, + "product_units_received": pl.String, + "product_units_released": pl.String, + "product_received_from": pl.String, + }, + ) + + out = _extract_balance_from_received(df) + + # DKSH (legitimate supplier) survives; Total (subtotal label) is nulled; + # null stays null; balance markers' received_from is unchanged. + assert out["product_received_from"].to_list() == ["43", "DKSH", None, "87", None] + + +def test_clean_received_from_copies_start_balance_and_strips_numeric_supplier(): + """Pin _clean_received_from's two paths: (1) START rows copy + received_from into product_balance; non-START rows preserve their + existing balance (defensive — earlier code wiped them, which only + happened to be a no-op because nothing populates product_balance + before this step). (2) numeric-only product_received_from is nulled + while alphabetic / mixed-alphanumeric values survive.""" + df = pl.DataFrame( + { + "product_units_received": ["START", "5", "3", "2"], + "product_received_from": ["DKSH", "DKSH", "43", "DKSH 43"], + "product_balance": [None, None, None, None], + "product_sheet_name": ["Jan", "Jan", "Jan", "Jan"], + }, + schema={ + "product_units_received": pl.String, + "product_received_from": pl.String, + "product_balance": pl.String, + "product_sheet_name": pl.String, + }, + ) + + out = _clean_received_from(df) + + # START row gets received_from copied into balance; other rows keep their + # original (null in production after apply_schema seed). + assert out["product_balance"].to_list() == ["DKSH", None, None, None] + # Alpha-strip on received_from: numeric-only "43" → null, alpha + mixed survive. + assert out["product_received_from"].to_list() == ["DKSH", "DKSH", None, "DKSH 43"] + + +def test_clean_received_from_no_start_preserves_balance(): + """When no row contains START, product_balance must be untouched. + Only the alpha-strip on received_from runs.""" + df = pl.DataFrame( + { + "product_units_received": ["5", "3"], + "product_received_from": ["DKSH", "43"], + "product_balance": ["100", "200"], # pre-existing sentinels + "product_sheet_name": ["Jan", "Jan"], + }, + schema={ + "product_units_received": pl.String, + "product_received_from": pl.String, + "product_balance": pl.String, + "product_sheet_name": pl.String, + }, + ) + + out = _clean_received_from(df) + + assert out["product_balance"].to_list() == ["100", "200"] + assert out["product_received_from"].to_list() == ["DKSH", None] + + +def test_clean_received_from_preserves_balance_on_sheets_without_start(): + """Two-sheet frame, START only in sheet A. 
Sheet B's pre-existing + product_balance must be preserved — a START in one sheet must not + propagate balance assignments (or wipes) to other sheets.""" + df = pl.DataFrame( + { + "product_units_received": ["START", "5", "3", "2"], + "product_received_from": ["DKSH", "DKSH", "Acme", "BMS"], + "product_balance": [None, None, "50", "75"], + "product_sheet_name": ["Jan", "Jan", "Feb", "Feb"], + }, + schema={ + "product_units_received": pl.String, + "product_received_from": pl.String, + "product_balance": pl.String, + "product_sheet_name": pl.String, + }, + ) + + out = _clean_received_from(df) + + # Sheet Jan: START row copies received_from to balance; other Jan row keeps null. + # Sheet Feb: no START → both rows preserve their existing balance ("50", "75"). + assert out["product_balance"].to_list() == ["DKSH", None, "50", "75"] + + +def test_format_dates_normalizes_separator_typos(): + """Cells with non-standard separators (period after day, underscore + before year, repeated dashes) parse to the intended date after + pre-normalization (regression for 2020 CDA Feb20, 2024 Likas Jan24, + 2022 Mukdahan Apr22). The Excel-datetime time-strip ('YYYY-MM-DD + HH:MM:SS') and the canonical 'dd-Mon-yyyy' format must continue to + parse correctly.""" + df = pl.DataFrame( + { + "product": ["P1", "P2", "P3", "P4", "P5"], + "product_entry_date": [ + "24.Feb 2020", # CDA period-after-day typo + "22-Jan_2024", # Likas underscore-before-year typo + "16-Apr--2022", # Mukdahan double-dash typo + "2020-02-24 00:00:00", # Excel datetime cast — time must strip + "16-Apr-2022", # canonical dd-Mon-yyyy — must parse + ], + }, + schema={ + "product": pl.String, + "product_entry_date": pl.String, + }, + ) + + out = _format_dates(df, ErrorCollector()) + + assert out["product_entry_date"].to_list() == [ + date(2020, 2, 24), + date(2024, 1, 22), + date(2022, 4, 16), + date(2020, 2, 24), + date(2022, 4, 16), + ] + + +def test_format_dates_preserves_year_typo_sentinels(): + """Year typos (real Excel datetime with wrong year, 5-digit year + strings) must continue to fail parsing — separator normalization + must not accidentally recover them. Mahosot's datetime(2009,12,4) + parses to a real 2009 date (Python validates it later via + _validate_entry_dates); NPH's 5-digit year falls through to the + error path.""" + df = pl.DataFrame( + { + "product": ["Mahosot", "NPH"], + "product_entry_date": [ + "2009-12-04 00:00:00", # Mahosot Excel datetime, year typo + "1 jun 20203", # NPH 5-digit year + ], + }, + schema={ + "product": pl.String, + "product_entry_date": pl.String, + }, + ) + + out = _format_dates(df, ErrorCollector()) + + parsed = out["product_entry_date"].to_list() + # Mahosot: real datetime parses to 2009-12-04 (year-floor catches later). + assert parsed[0] == date(2009, 12, 4) + # NPH: 5-digit year is unparseable; parse_date_column emits a sentinel + # error_val_date or None — either way, not a valid 2020/2023 date. 
+ assert parsed[1] is None or parsed[1] != date(2023, 6, 1) + assert parsed[1] != date(2020, 6, 1) + + +def test_split_multi_product_cells_extracts_box_count_to_units_received(): + """Pins the \\d+ deviation from R's [1-9]+: a count starting with the + digit '1' followed by '0' must be extracted whole as '10', not '1'.""" + df = pl.DataFrame( + { + "product": ["Accu-Check Strips (10 box)"], + "product_received_from": ["DKSH"], + "product_released_to": [None], + "product_units_received": [None], + "product_units_released": [None], + }, + schema={ + "product": pl.String, + "product_received_from": pl.String, + "product_released_to": pl.String, + "product_units_received": pl.String, + "product_units_released": pl.String, + }, + ) + + out = _split_multi_product_cells(df) + + assert out["product_units_notes"].to_list() == ["10 box"] + assert out["product_units_received"].to_list() == ["10"] + assert out["product_units_released"].to_list() == [None] + + +def test_split_multi_product_cells_extracts_unit_count_to_released_when_released_to_set(): + """Same shape as the box test but routes the count to product_units_released + when product_released_to is set and product_received_from is null.""" + df = pl.DataFrame( + { + "product": ["Accu-Check Strips (10 unit)"], + "product_received_from": [None], + "product_released_to": ["KD_EW001"], + "product_units_received": [None], + "product_units_released": [None], + }, + schema={ + "product": pl.String, + "product_received_from": pl.String, + "product_released_to": pl.String, + "product_units_received": pl.String, + "product_units_released": pl.String, + }, + ) + + out = _split_multi_product_cells(df) + + assert out["product_units_notes"].to_list() == ["10 unit"] + assert out["product_units_released"].to_list() == ["10"] + assert out["product_units_received"].to_list() == [None] + + +def test_misplaced_datetime_cell_in_units_received_logs_and_recodes_to_zero(): + """Pin Python's correct handling of a date typo'd into the + units_received column. R's read_excel reads the underlying Excel + serial (e.g. 43644.0 for 2019-06-28), polluting balance trajectories + by tens of thousands. Python's openpyxl path stringifies, the + Float64 cast in step 2.11 fails, step 2.12 recodes the null to 0.0. + Step 2.11 also emits one ``type_conversion`` ErrorCollector entry — + R-parity with ``preparing_product_fields``'s ``invalid_value`` warning. 
+ Regression for the 22 real-divergence rows in + 2019_Sultanah Bahiyah_DC / 2019_Penang_DC / 2020 Mandalay + surfaced by Ali_internship/product_balance_diff_v1.ipynb.""" + df = pl.DataFrame( + { + "file_name": ["2019_SB.xlsx"], + "product_sheet_name": ["Jun19"], + "product": ["Accu-Chek Performa Test Strip"], + "product_entry_date": [None], + "product_units_received": ["2019-06-28 00:00:00"], # typo cell + "product_units_released": [None], + "product_received_from": [None], + "product_released_to": [None], + "product_units_returned": [None], + "product_returned_by": [None], + "product_balance": [None], + "product_table_month": ["06"], + "product_table_year": [2019.0], + "product_balance_status": [None], + "product_category": [None], + "product_unit_capacity": [None], + "product_units_notes": [None], + "orig_product_released_to": [None], + "product_remarks": [None], + }, + schema={ + "file_name": pl.String, + "product_sheet_name": pl.String, + "product": pl.String, + "product_entry_date": pl.String, + "product_units_received": pl.String, + "product_units_released": pl.String, + "product_received_from": pl.String, + "product_released_to": pl.String, + "product_units_returned": pl.String, + "product_returned_by": pl.String, + "product_balance": pl.String, + "product_table_month": pl.String, + "product_table_year": pl.Float64, + "product_balance_status": pl.String, + "product_category": pl.String, + "product_unit_capacity": pl.String, + "product_units_notes": pl.String, + "orig_product_released_to": pl.String, + "product_remarks": pl.String, + }, + ) + + collector = ErrorCollector() + out = clean_product_data(df, collector) + + assert out["product_units_received"].to_list() == [0.0] + + units_errors = [ + e for e in collector.errors + if e.column == "product_units_received" and e.function_name == "_clean_units_received" + ] + assert len(units_errors) == 1 + err = units_errors[0] + assert err.error_code == "type_conversion" + assert err.original_value == "2019-06-28 00:00:00" + assert err.file_name == "2019_SB.xlsx" + assert err.patient_id == "Accu-Chek Performa Test Strip" + + +def test_running_balance_eliminates_float_residue(): + """Pin .round(10) on the cumsum result. Without it, Python's vectorized + cumsum yields 2.220446e-16 instead of 0 on rows where the running delta + sum hits a non-binary-clean target (e.g. 1.8 - 0.4 - 1.4 = 2.22e-16 + because 0.4+1.4 in float64 is 1.7999999999999998). R's row-by-row + recurrence avoids this. Regression for the 600 FP-noise rows surfaced + by Ali_internship/product_balance_diff_v3.ipynb (V3 corpus example: + 2021_Lao Friends, Feb21, Mixtard 30 Penfill 3ml (5s)).""" + df = pl.DataFrame( + { + "product_sheet_name": ["Feb21"] * 4, + "product": ["Mixtard"] * 4, + "product_balance": [1.8, None, None, None], + "product_balance_status": ["start", "change", "change", "end"], + "product_units_received": [0.0, 0.0, 0.0, 0.0], + "product_units_released": [0.0, 0.4, 1.4, 0.0], + "product_units_returned": [0.0, 0.0, 0.0, 0.0], + "product_table_year": [2020, 2020, 2020, 2020], # < 2021 to skip end-row zeroing + }, + schema={ + "product_sheet_name": pl.String, + "product": pl.String, + "product_balance": pl.Float64, + "product_balance_status": pl.String, + "product_units_received": pl.Float64, + "product_units_released": pl.Float64, + "product_units_returned": pl.Float64, + "product_table_year": pl.Int32, + }, + ) + + out = _compute_running_balance(df) + + # Strict equality — without .round(10) the 3rd row would be 2.22e-16. 
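+    # Expected trajectory: 1.8 (start), 1.8 - 0.4 = 1.4, 1.4 - 1.4 = 0.0, 0.0 (end).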
+ assert out["product_balance"].to_list() == [1.8, 1.4, 0.0, 0.0] + + +def test_fill_product_names_and_sort_treats_sentinel_as_null_for_rank(): + """Pin the sentinel-rank fix. The 9999-09-09 parse-failure sentinel must + be treated as null when computing the per-(sheet, product) rank, so it + falls into the 'preserve input order' branch rather than sorting to the + dense_d+1 end-of-changes position. Real-world driver: the 11 Tier-2 rows + in product_balance documented in product_balance_investigation.md, where + R returns null for unparseable dates and Python returns 9999-09-09 — the + sentinel was sorting after every valid date, distorting the cumulative + balance order vs. R.""" + df = pl.DataFrame( + { + "product_sheet_name": ["Jun24"] * 5, + "product": ["P"] * 5, + "product_entry_date": [None, date(2024, 1, 15), date(9999, 9, 9), date(2024, 3, 15), None], + "product_balance_status": ["start", "change", "change", "change", "end"], + "product_table_month": [6] * 5, + "index": [1, 2, 3, 4, 5], + }, + schema={ + "product_sheet_name": pl.String, + "product": pl.String, + "product_entry_date": pl.Date, + "product_balance_status": pl.String, + "product_table_month": pl.Int32, + "index": pl.Int64, + }, + ) + + out = _fill_product_names_and_sort(df) + + # After fix: sentinel sits between the two valid mid-rows in input order + # (rank=row_n=3 ties with rank for 2024-03-15=dense_d+1=3, stable sort + # keeps the sentinel at input position 3, valid date at position 4). + statuses = out["product_balance_status"].to_list() + assert statuses == ["start", "change", "change", "change", "end"] + dates = out["product_entry_date"].to_list() + assert dates[0] is None # start + assert dates[1] == date(2024, 1, 15) # earlier valid date + assert dates[2] == date(9999, 9, 9) # sentinel — between A and B + assert dates[3] == date(2024, 3, 15) # later valid date + assert dates[4] is None # end + + +def _check_dates_df( + entry_dates: list[date | None], + table_month: int = 6, + table_year: int = 2024, +) -> pl.DataFrame: + """Fixture for _check_entry_dates_match_sheet. 
Mirrors the post-2.6 schema: + table_month is still String, table_year is still Float64 (the schema cast + at step 2.16 hasn't run yet).""" + n = len(entry_dates) + return pl.DataFrame( + { + "product": [f"P{i}" for i in range(n)], + "product_entry_date": entry_dates, + "product_table_year": [float(table_year)] * n, + "product_table_month": [f"{table_month:02d}"] * n, + "product_sheet_name": [f"{table_month:02d}"] * n, + "file_name": ["t.xlsx"] * n, + }, + schema={ + "product": pl.String, + "product_entry_date": pl.Date, + "product_table_year": pl.Float64, + "product_table_month": pl.String, + "product_sheet_name": pl.String, + "file_name": pl.String, + }, + ) + + +def test_check_entry_dates_logs_month_mismatch(): + df = _check_dates_df( + entry_dates=[date(2024, 6, 15), date(2024, 3, 15)], + table_month=6, + table_year=2024, + ) + collector = ErrorCollector() + + _check_entry_dates_match_sheet(df, collector) + + assert len(collector) == 1 + err = collector.errors[0] + assert err.column == "product_entry_date" + assert err.error_code == "invalid_value" + assert err.function_name == "check_entry_dates" + assert err.patient_id == "P1" + + +def test_check_entry_dates_logs_year_mismatch(): + df = _check_dates_df( + entry_dates=[date(2023, 6, 15)], # year mismatch even though month matches + table_month=6, + table_year=2024, + ) + collector = ErrorCollector() + _check_entry_dates_match_sheet(df, collector) + assert len(collector) == 1 + + +def test_check_entry_dates_skips_sentinel_buddhist_and_null(): + df = _check_dates_df( + entry_dates=[ + None, # null — skip + date(9999, 9, 9), # parse-failure sentinel — skip + date(2567, 11, 11), # Buddhist-era — skip + date(2024, 6, 15), # match — skip + ], + table_month=6, + table_year=2024, + ) + collector = ErrorCollector() + _check_entry_dates_match_sheet(df, collector) + assert len(collector) == 0 + + +def test_check_entry_dates_no_log_on_match(): + df = _check_dates_df( + entry_dates=[date(2024, 6, 1), date(2024, 6, 30)], + table_month=6, + table_year=2024, + ) + collector = ErrorCollector() + _check_entry_dates_match_sheet(df, collector) + assert len(collector) == 0 + + +def test_fat_finger_future_falls_into_input_order_rank_after_validation(): + """Pin the deliberate _validate_entry_dates × _fill_product_names_and_sort + interaction. R keeps `2099-03-15` as a valid future date and ranks it by + value; Python clobbers it to the 9999-09-09 sentinel inside + _validate_entry_dates and then treats the sentinel as null when ranking. + + This is a documented divergence from R. 
The test exists so any future + revert of either _validate_entry_dates' future-year guard or the + sentinel-rank fix in _fill_product_names_and_sort breaks loudly.""" + collector = ErrorCollector() + df = pl.DataFrame( + { + "product": ["P", "P", "P", "P"], + "product_entry_date": [None, date(2099, 3, 15), date(2024, 6, 1), None], + "product_balance_status": ["start", "change", "change", "end"], + "product_table_year": [2024] * 4, + "product_table_month": [6] * 4, + "product_sheet_name": ["Jun24"] * 4, + "file_name": ["t.xlsx"] * 4, + "index": [1, 2, 3, 4], + }, + schema={ + "product": pl.String, + "product_entry_date": pl.Date, + "product_balance_status": pl.String, + "product_table_year": pl.Int32, + "product_table_month": pl.Int32, + "product_sheet_name": pl.String, + "file_name": pl.String, + "index": pl.Int64, + }, + ) + + validated = _validate_entry_dates(df, collector) + out = _fill_product_names_and_sort(validated) + + # _validate_entry_dates rewrote the 2099 future to the sentinel. + assert validated["product_entry_date"].to_list()[1] == date(9999, 9, 9) + assert len(collector) == 1 + + # In the sorted output, the (now-sentinel) row sits at its input position + # 2 — not the dense-rank tail. Statuses confirm row ordering. + assert out["product_balance_status"].to_list() == [ + "start", + "change", + "change", + "end", + ] + dates = out["product_entry_date"].to_list() + assert dates[0] is None # start + assert dates[1] == date(9999, 9, 9) # was 2099-03-15, now sentinel-as-null + assert dates[2] == date(2024, 6, 1) # valid mid-row + assert dates[3] is None # end diff --git a/tests/test_clean/test_validators.py b/tests/test_clean/test_validators.py index d662181..84c3cbc 100644 --- a/tests/test_clean/test_validators.py +++ b/tests/test_clean/test_validators.py @@ -325,6 +325,66 @@ def test_validate_allowed_values_case_insensitive(): assert len(collector) == 0 # No errors - "y" is valid +def test_validate_allowed_values_csv_subset(): + """CSV values whose every token is in allowed_values emit canonical CSV.""" + allowed = ["Pre-mixed", "Short-acting", "Intermediate-acting", "Rapid-acting", "Long-acting"] + df = pl.DataFrame( + { + "file_name": ["test.xlsx"] * 5, + "patient_id": ["XX_YY001", "XX_YY002", "XX_YY003", "XX_YY004", "XX_YY005"], + "insulin_subtype": [ + "rapid-acting", # single valid + "pre-mixed,rapid-acting", # two-token CSV valid + "pre-mixed,rapid-acting,long-acting", # three-token CSV valid + "pre-mixed,unknown-thing", # CSV with one bad token + "not-in-list", # single invalid + ], + } + ) + + collector = ErrorCollector() + result = validate_allowed_values( + df=df, + column="insulin_subtype", + allowed_values=allowed, + error_collector=collector, + replace_invalid=True, + allow_csv_subset=True, + ) + + assert result["insulin_subtype"].to_list() == [ + "Rapid-acting", + "Pre-mixed,Rapid-acting", + "Pre-mixed,Rapid-acting,Long-acting", + settings.error_val_character, + settings.error_val_character, + ] + assert len(collector) == 2 + + +def test_validate_allowed_values_csv_subset_disabled(): + """Without allow_csv_subset, CSV values are treated as single strings and fail.""" + allowed = ["Pre-mixed", "Rapid-acting"] + df = pl.DataFrame( + { + "file_name": ["test.xlsx"], + "patient_id": ["XX_YY001"], + "insulin_subtype": ["pre-mixed,rapid-acting"], + } + ) + + collector = ErrorCollector() + result = validate_allowed_values( + df=df, + column="insulin_subtype", + allowed_values=allowed, + error_collector=collector, + replace_invalid=True, + ) + + assert 
result["insulin_subtype"].to_list() == [settings.error_val_character] + + # Tests for fix_patient_id diff --git a/tests/test_cli/test_cli.py b/tests/test_cli/test_cli.py index 5c3baea..d6a40b5 100644 --- a/tests/test_cli/test_cli.py +++ b/tests/test_cli/test_cli.py @@ -43,6 +43,7 @@ def test_run_pipeline_help(self): assert result.exit_code == 0 assert "--skip-download" in result.output assert "--skip-upload" in result.output + assert "--skip-patient" in result.output # --------------------------------------------------------------------------- @@ -141,6 +142,59 @@ def test_pipeline_failure_exits_nonzero(self, mock_settings, mock_run_pipeline, assert result.exit_code == 1 + @patch("a4d.cli.run_product_pipeline") + @patch("a4d.cli.run_patient_pipeline") + @patch("a4d.config.settings") + def test_skip_patient_runs_product_only( + self, mock_settings, mock_run_patient, mock_run_product, tmp_path + ): + mock_settings.data_root = tmp_path / "data" + mock_settings.output_root = tmp_path / "output" + mock_settings.project_id = "test-project" + mock_settings.dataset = "test-dataset" + mock_settings.max_workers = 4 + + (tmp_path / "data").mkdir() + (tmp_path / "output").mkdir() + + mock_product = MagicMock() + mock_product.success = True + mock_product.total_trackers = 0 + mock_product.successful_trackers = 0 + mock_product.failed_trackers = 0 + mock_product.tracker_results = [] + mock_run_product.return_value = mock_product + + result = runner.invoke( + app, + [ + "run-pipeline", + "--skip-patient", + "--skip-download", + "--skip-upload", + "--skip-drive-download", + ], + ) + + assert result.exit_code == 0, f"Pipeline failed:\n{result.output}" + mock_run_patient.assert_not_called() + mock_run_product.assert_called_once() + + def test_skip_patient_and_skip_product_mutually_exclusive(self, tmp_path): + result = runner.invoke( + app, + [ + "run-pipeline", + "--skip-patient", + "--skip-product", + "--skip-download", + "--skip-upload", + "--skip-drive-download", + ], + ) + assert result.exit_code == 1 + assert "mutually exclusive" in result.output.lower() + # --------------------------------------------------------------------------- # End-to-end test: process-patient with real dummy tracker diff --git a/tests/test_cli/test_force_incremental.py b/tests/test_cli/test_force_incremental.py new file mode 100644 index 0000000..9554ef8 --- /dev/null +++ b/tests/test_cli/test_force_incremental.py @@ -0,0 +1,162 @@ +"""Tests for --force and --incremental flag interaction across CLI commands.""" + +import hashlib +from pathlib import Path + +from typer.testing import CliRunner + +from a4d.cli import app + +runner = CliRunner(env={"NO_COLOR": "1", "COLUMNS": "200"}) + + +def _hash_dir(dir_path: Path) -> dict[str, str]: + """Return {relative_path: sha256} for the deterministic data outputs. + + Restricted to ``patient_data_cleaned/*.parquet`` and the patient-table + parquets. Skips ``logs/`` (per-tracker JSON contains timestamps) and + ``table_logs.parquet`` (aggregated log timestamps). 
+ """ + out: dict[str, str] = {} + cleaned = dir_path / "patient_data_cleaned" + if cleaned.exists(): + for f in sorted(cleaned.glob("*.parquet")): + out[f"patient_data_cleaned/{f.name}"] = hashlib.sha256(f.read_bytes()).hexdigest() + tables = dir_path / "tables" + if tables.exists(): + for f in sorted(tables.glob("patient_data_*.parquet")): + out[f"tables/{f.name}"] = hashlib.sha256(f.read_bytes()).hexdigest() + return out + + +class TestHelpExposesForce: + """Every command that takes --incremental must also expose --force.""" + + def test_process_patient_help_mentions_force(self): + result = runner.invoke(app, ["process-patient", "--help"]) + assert result.exit_code == 0 + assert "--force" in result.output + + def test_process_product_help_mentions_force(self): + result = runner.invoke(app, ["process-product", "--help"]) + assert result.exit_code == 0 + assert "--force" in result.output + + def test_run_pipeline_help_mentions_force(self): + result = runner.invoke(app, ["run-pipeline", "--help"]) + assert result.exit_code == 0 + assert "--force" in result.output + + +class TestForceFlag: + """--force semantics on process-patient (the one CLI command we can drive end-to-end).""" + + def test_force_produces_same_output_as_default(self, dummy_tracker_dir, tmp_path): + """--force is an explicit synonym for the default — outputs must match byte-for-byte.""" + out_default = tmp_path / "out_default" + out_force = tmp_path / "out_force" + + r1 = runner.invoke( + app, + ["process-patient", "--data-root", str(dummy_tracker_dir), "--output", str(out_default)], + ) + assert r1.exit_code == 0, f"default run failed:\n{r1.output}" + + r2 = runner.invoke( + app, + ["process-patient", "--data-root", str(dummy_tracker_dir), "--output", str(out_force), "--force"], + ) + assert r2.exit_code == 0, f"--force run failed:\n{r2.output}" + + assert _hash_dir(out_default) == _hash_dir(out_force) + + def test_force_is_deterministic(self, dummy_tracker_dir, tmp_path): + """Two consecutive --force runs into separate dirs produce identical outputs.""" + out_a = tmp_path / "a" + out_b = tmp_path / "b" + + for out in (out_a, out_b): + r = runner.invoke( + app, + ["process-patient", "--data-root", str(dummy_tracker_dir), "--output", str(out), "--force"], + ) + assert r.exit_code == 0, f"run into {out} failed:\n{r.output}" + + assert _hash_dir(out_a) == _hash_dir(out_b) + + def test_force_wipes_existing_outputs(self, dummy_tracker_dir, tmp_path): + """--force must trigger the orchestrator's clean_output=True wipe. + + Pre-seed sentinel files in all four dirs the orchestrator wipes + (patient_data_raw, patient_data_cleaned, tables, logs — see + pipeline/patient.py:176) and confirm they're gone after --force runs. + Without this check the determinism tests would pass even if --force + silently became a no-op (since they hash output equivalence into fresh + empty dirs, not the wipe behavior itself). 
+ """ + out = tmp_path / "out" + wipe_dirs = ("patient_data_raw", "patient_data_cleaned", "tables", "logs") + for subdir in wipe_dirs: + (out / subdir).mkdir(parents=True) + (out / subdir / "_sentinel").write_text("from prior run") + + result = runner.invoke( + app, + ["process-patient", "--data-root", str(dummy_tracker_dir), "--output", str(out), "--force"], + ) + assert result.exit_code == 0, f"--force run failed:\n{result.output}" + + for subdir in wipe_dirs: + assert not (out / subdir / "_sentinel").exists(), ( + f"--force did not wipe {subdir}/_sentinel" + ) + + +class TestForceIncrementalConflict: + """--force overrides --incremental with a warning.""" + + def test_warning_printed_when_both_set(self, dummy_tracker_dir, tmp_path): + out = tmp_path / "out" + result = runner.invoke( + app, + [ + "process-patient", + "--data-root", str(dummy_tracker_dir), + "--output", str(out), + "--force", + "--incremental", + ], + ) + assert result.exit_code == 0, f"combined run failed:\n{result.output}" + assert "--incremental is ignored when --force is set" in result.output + + def test_combined_output_matches_force_alone(self, dummy_tracker_dir, tmp_path): + """--force --incremental should behave exactly like --force (incremental ignored).""" + out_force = tmp_path / "force" + out_both = tmp_path / "both" + + r_force = runner.invoke( + app, + ["process-patient", "--data-root", str(dummy_tracker_dir), "--output", str(out_force), "--force"], + ) + assert r_force.exit_code == 0 + + r_both = runner.invoke( + app, + [ + "process-patient", + "--data-root", str(dummy_tracker_dir), + "--output", str(out_both), + "--force", + "--incremental", + ], + ) + assert r_both.exit_code == 0 + + assert _hash_dir(out_force) == _hash_dir(out_both) + + +# Note: a "second incremental run skips unchanged" test would need to seed the +# tracker_metadata manifest first — process-patient does not create it (only +# create-tables / run-pipeline do). That code path is exercised by +# tests/test_state/test_integration.py, so we don't duplicate it here. 
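
For orientation only (not part of the patch): the note above points at the manifest round-trip that tests/test_state/test_integration.py covers. Below is a minimal sketch of how the helpers added in this PR chain together for that round-trip; the function name `preview_incremental_skip` is hypothetical and simply strings the real helpers together the way the integration test does.

```python
from pathlib import Path

from a4d.state.filter import filter_unchanged_trackers
from a4d.state.source import load_previous_manifest
from a4d.tables.metadata import create_table_tracker_metadata


def preview_incremental_skip(data_root: Path, output_root: Path):
    """Hypothetical helper: which trackers would an --incremental run still process?"""
    # Publish the manifest from the current outputs (the same call that
    # create-tables / run-pipeline make at the end of a run).
    create_table_tracker_metadata(data_root, output_root)

    # Load it back from the local parquet (skipping the BigQuery lookup) and
    # classify every tracker on disk against it.
    manifest = load_previous_manifest(output_root, prefer_bigquery=False)
    discovered = sorted(data_root.rglob("*.xlsx"))
    return filter_unchanged_trackers(discovered, manifest)
```

A "second incremental run skips unchanged" test could call this once, mutate one tracker's bytes, and call it again to assert that only the mutated tracker is queued.
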
diff --git a/tests/test_extract/test_common.py b/tests/test_extract/test_common.py new file mode 100644 index 0000000..c13f42c --- /dev/null +++ b/tests/test_extract/test_common.py @@ -0,0 +1,90 @@ +"""Unit tests for shared tracker-level extraction helpers.""" + +from pathlib import Path + +import polars as pl +import pytest + +from a4d.extract.common import ( + clean_excel_errors, + extract_tracker_month, + find_month_sheets, + get_tracker_year, +) + + +class _StubWorkbook: + """Minimal stand-in for openpyxl.Workbook that exposes ``sheetnames``.""" + + def __init__(self, sheetnames: list[str]): + self.sheetnames = sheetnames + + +def test_get_tracker_year_from_sheet_names(): + year = get_tracker_year(Path("anything.xlsx"), ["Jan24", "Feb24", "Mar24"]) + assert year == 2024 + + +def test_get_tracker_year_falls_back_to_filename(): + year = get_tracker_year(Path("2023_clinic_tracker.xlsx"), ["January", "February"]) + assert year == 2023 + + +def test_get_tracker_year_raises_when_unparseable(): + with pytest.raises(ValueError): + get_tracker_year(Path("clinic_tracker.xlsx"), ["January", "February"]) + + +def test_find_month_sheets_filters_and_sorts(): + wb = _StubWorkbook(["Cover", "Mar24", "Jan24", "Feb24", "Notes"]) + sheets = find_month_sheets(wb) + assert sheets == ["Jan24", "Feb24", "Mar24"] + + +def test_clean_excel_errors_replaces_with_null(): + df = pl.DataFrame( + {"bmi": ["17.5", "#DIV/0!", "18.2", "#VALUE!"]}, + schema={"bmi": pl.String}, + ) + cleaned = clean_excel_errors(df) + assert cleaned["bmi"].to_list() == ["17.5", None, "18.2", None] + + +def test_clean_excel_errors_skips_non_string_columns(): + df = pl.DataFrame( + {"product_table_year": [2024.0, 2024.0]}, + schema={"product_table_year": pl.Float64}, + ) + cleaned = clean_excel_errors(df) + assert cleaned.equals(df) + + +def test_extract_tracker_month_known_prefix(): + assert extract_tracker_month("Jan24") == 1 + assert extract_tracker_month("Dec23") == 12 + + +def test_extract_tracker_month_unknown_prefix(): + with pytest.raises(ValueError): + extract_tracker_month("Foo24") + + +def test_re_export_from_patient_module_still_works(): + """Backward-compat re-export shim stays callable from a4d.extract.patient.""" + from a4d.extract.patient import ( + clean_excel_errors as p_clean_excel_errors, + ) + from a4d.extract.patient import ( + extract_tracker_month as p_extract_tracker_month, + ) + from a4d.extract.patient import ( + find_month_sheets as p_find_month_sheets, + ) + from a4d.extract.patient import ( + get_tracker_year as p_get_tracker_year, + ) + + assert p_clean_excel_errors is clean_excel_errors + assert p_extract_tracker_month is extract_tracker_month + assert p_find_month_sheets is find_month_sheets + assert p_get_tracker_year is get_tracker_year diff --git a/tests/test_extract/test_product.py b/tests/test_extract/test_product.py new file mode 100644 index 0000000..d3db878 --- /dev/null +++ b/tests/test_extract/test_product.py @@ -0,0 +1,385 @@ +"""Unit tests for product data extraction (`a4d.extract.product`).""" + +from pathlib import Path +from unittest.mock import Mock + +import polars as pl +import pytest +from openpyxl import Workbook + +from a4d.errors import ErrorCollector +from a4d.extract.product import ( + ProductSectionNotFoundError, + _count_orphan_released_units, + _harmonize, + add_product_metadata, + extract_product_data, + find_product_section, + read_all_product_sheets, + remove_header_rows, + replace_extra_totals, +) + + +def _make_ws(rows: list[list]): + """Build an in-memory openpyxl worksheet 
with the given rows. + + Returns the worksheet handle (1-indexed Excel rows). Empty cells use + ``None``. + """ + wb = Workbook() + ws = wb.active + for row in rows: + ws.append(row) + return ws + + +def test_find_product_section_2024_layout(): + """Header row contains product/date/received; patient section follows.""" + ws = _make_ws( + [ + [None, None, None, None, None], + ["Product", "Date", "Units Received", "From", "Released"], + ["Insulin", "2024-06-01", 100, "DKSH", 5], + ["Strips", "2024-06-02", 50, "DKSH", None], + ["Patient Recruitment", None, None, None, None], + ["Patient Name", "Patient ID", None, None, None], + ] + ) + start, end = find_product_section(ws) + assert start == 2 + assert end == 4 + + +def test_find_product_section_raises_on_missing_start(): + """No header row with product/date/received markers triggers the error.""" + ws = _make_ws( + [ + ["Random", "Junk", "Headers"], + ["Patient Name", "Patient ID", None], + ] + ) + with pytest.raises(ProductSectionNotFoundError): + find_product_section(ws) + + +def test_find_product_section_raises_on_missing_end(): + """Header found but no patient-section terminator triggers the error.""" + ws = _make_ws( + [ + ["Product", "Date", "Units Received"], + ["Insulin", "2024-06-01", 100], + ] + ) + with pytest.raises(ProductSectionNotFoundError): + find_product_section(ws) + + +def _make_mapper(known_to_standard: dict[str, str]): + """Build a mock ColumnMapper. + + ``known_to_standard`` maps source column names to the standardized name + they should be renamed to. Anything not in the dict is treated as + unknown (caught by ``is_known_column`` returning False) and dropped. + """ + mapper = Mock() + mapper.synonyms = {standard: [src] for src, standard in known_to_standard.items()} + mapper.is_known_column = lambda col: col in known_to_standard + mapper.rename_columns = lambda df: df.rename(known_to_standard) + return mapper + + +def test_harmonize_renames_known_drops_unknown(): + df = pl.DataFrame( + {"Product Name": ["Insulin"], "Random Junk": ["x"]}, + schema={"Product Name": pl.String, "Random Junk": pl.String}, + ) + mapper = _make_mapper({"Product Name": "product"}) + collector = ErrorCollector() + + out = _harmonize(df, mapper, sheet_name="Jun24", error_collector=collector, file_name="t.xlsx") + + assert out.columns == ["product"] + assert len(collector) == 1 + assert collector.errors[0].column == "Random Junk" + assert collector.errors[0].error_code == "invalid_tracker" + assert collector.errors[0].function_name == "harmonize_input_data_columns" + + +def test_harmonize_no_unknowns_no_log(): + df = pl.DataFrame( + {"Product Name": ["Insulin"]}, + schema={"Product Name": pl.String}, + ) + mapper = _make_mapper({"Product Name": "product"}) + collector = ErrorCollector() + + out = _harmonize(df, mapper, sheet_name="Jun24", error_collector=collector, file_name="t.xlsx") + + assert out.columns == ["product"] + assert len(collector) == 0 + + +def test_replace_extra_totals_masks_after_total_column(): + """If a column to the immediate left of product_units_released contains + 'total' (case-insensitive), the released value is nulled.""" + df = pl.DataFrame( + { + "product": ["A", "B"], + "Total Released": ["Total", None], + "product_units_released": ["10", "20"], + }, + schema={ + "product": pl.String, + "Total Released": pl.String, + "product_units_released": pl.String, + }, + ) + out = replace_extra_totals(df) + assert out["product_units_released"].to_list() == [None, "20"] + + +def test_replace_extra_totals_preserves_clean_rows(): + 
"""No 'total' marker in either of the two preceding columns leaves + product_units_released unchanged.""" + df = pl.DataFrame( + { + "product": ["A", "B"], + "product_received_from": ["DKSH", None], + "product_released_to": ["P1", "P2"], + "product_units_released": ["10", "20"], + }, + schema={ + "product": pl.String, + "product_received_from": pl.String, + "product_released_to": pl.String, + "product_units_released": pl.String, + }, + ) + out = replace_extra_totals(df) + assert out["product_units_released"].to_list() == ["10", "20"] + + +def test_replace_extra_totals_case_insensitive(): + """'total' / 'TOTAL' / 'Total' all trigger masking.""" + df = pl.DataFrame( + { + "product": ["A"], + "TOTAL": ["TOTAL"], + "product_units_released": ["10"], + }, + schema={ + "product": pl.String, + "TOTAL": pl.String, + "product_units_released": pl.String, + }, + ) + out = replace_extra_totals(df) + assert out["product_units_released"].to_list() == [None] + + +def test_extract_product_data_promotes_headers_and_types_strings(): + """First row becomes column names; all values are coerced to pl.String.""" + rows = [ + ["Cover", None, None], # 1: pre-section + ["Product", "Date", "Units"], # 2: header + ["Insulin", "2024-06-01", 100], # 3 + ["Strips", "2024-06-02", 50], # 4 + ] + ws = _make_ws(rows) + + out = extract_product_data(ws, start_row=2, end_row=4) + + assert out.columns == ["Product", "Date", "Units"] + assert out.dtypes == [pl.String, pl.String, pl.String] + assert out["Product"].to_list() == ["Insulin", "Strips"] + assert out["Units"].to_list() == ["100", "50"] + + +def test_extract_product_data_merges_duplicate_headers_with_comma(): + """When two columns share a header, values are joined with ','.""" + rows = [ + ["Product", "Note", "Note"], + ["Insulin", "A", "B"], + ["Strips", None, "C"], + ] + ws = _make_ws(rows) + + out = extract_product_data(ws, start_row=1, end_row=3) + assert "Note" in out.columns + assert out["Note"].to_list() == ["A,B", "C"] + + +def test_extract_product_data_returns_empty_for_short_section(): + rows = [["Product", "Date"]] # only header, no data + ws = _make_ws(rows) + out = extract_product_data(ws, start_row=1, end_row=1) + assert out.height == 0 + assert out.width == 0 + + +def test_add_product_metadata_appends_five_cols(): + df = pl.DataFrame({"product": ["A"]}, schema={"product": pl.String}) + out = add_product_metadata(df, "Jun24", 6, 2024, "tracker.xlsx", "CL001") + assert out["product_table_month"].to_list() == ["06"] + assert out["product_table_year"].to_list() == [2024.0] + assert out["product_sheet_name"].to_list() == ["Jun24"] + assert out["file_name"].to_list() == ["tracker.xlsx"] + assert out["clinic_id"].to_list() == ["CL001"] + + +def test_remove_header_rows_drops_repeated_header_and_empty(): + df = pl.DataFrame( + { + "product": ["Product", "Insulin", None, "Patient Data Summary"], + "product_entry_date": [None, "2024-06", None, None], + }, + schema={"product": pl.String, "product_entry_date": pl.String}, + ) + out = remove_header_rows(df) + assert out["product"].to_list() == ["Insulin"] + + +def test_remove_header_rows_empty_input(): + df = pl.DataFrame({"product": []}, schema={"product": pl.String}) + out = remove_header_rows(df) + assert out.height == 0 + + +def test_read_all_product_sheets_end_to_end(tmp_path: Path): + """End-to-end: build a tracker workbook with one month sheet, stub the + mapper, and verify the orchestrator wires extract → harmonize → + metadata → totals correctly.""" + wb = Workbook() + # Remove the default sheet and add a 
Jun24 sheet. + wb.remove(wb.active) + ws = wb.create_sheet("Jun24") + + rows = [ + [None, None, None, None, None], # 1: pre-section + ["Product", "Date", "Units Received", "From", "Released"], # 2: header + ["Insulin", "2024-06-01", 100, "DKSH", 5], # 3: data + ["Strips", "2024-06-02", 50, "DKSH", None], # 4: data + ["Patient Recruitment", None, None, None, None], # 5: terminator + ["Patient Name", "Patient ID", None, None, None], # 6: end + ] + for row in rows: + ws.append(row) + + tracker_path = tmp_path / "2024_test_tracker.xlsx" + wb.save(tracker_path) + + mapper = _make_mapper( + { + "Product": "product", + "Date": "product_entry_date", + "Units Received": "product_units_received", + "From": "product_received_from", + "Released": "product_units_released", + } + ) + collector = ErrorCollector() + + out = read_all_product_sheets(tracker_path, mapper=mapper, error_collector=collector) + + assert out.height == 2 + assert "product" in out.columns + assert "product_table_year" in out.columns + assert "product_sheet_name" in out.columns + assert "clinic_id" in out.columns + assert out["product_sheet_name"].to_list() == ["Jun24", "Jun24"] + assert out["product_table_year"].to_list() == [2024.0, 2024.0] + assert out["clinic_id"].to_list() == [tracker_path.parent.name] * 2 + # No unknown columns; collector should be empty. + assert len(collector) == 0 + + +def test_read_all_product_sheets_no_month_sheets_raises(tmp_path: Path): + wb = Workbook() + # Default sheet has no month-prefix; remove and add a non-month sheet. + wb.remove(wb.active) + wb.create_sheet("Cover") + tracker_path = tmp_path / "2024_no_months.xlsx" + wb.save(tracker_path) + + mapper = _make_mapper({}) + + with pytest.raises(ValueError, match="No month sheets"): + read_all_product_sheets(tracker_path, mapper=mapper) + + +def test_count_orphan_released_units_logs_per_sheet(): + """3 orphan rows (released_to null while units_released non-null) + produce exactly 1 ErrorCollector entry naming the sheet and count.""" + df = pl.DataFrame( + { + "product_released_to": [None, None, None, "P1", "P2"], + "product_units_released": ["10", "20", "30", "40", None], + }, + schema={ + "product_released_to": pl.String, + "product_units_released": pl.String, + }, + ) + collector = ErrorCollector() + + _count_orphan_released_units(df, sheet_name="Jul24", file_name="t.xlsx", error_collector=collector) + + assert len(collector) == 1 + err = collector.errors[0] + assert err.error_code == "invalid_tracker" + assert err.function_name == "read_product_data_step1" + assert err.column == "product_released_to" + assert "Jul24" in err.error_message + assert "3" in err.error_message + + +def test_count_orphan_released_units_zero_rows_no_log(): + """Clean fixture (all releases have a recipient) emits no log.""" + df = pl.DataFrame( + { + "product_released_to": ["P1", "P2", None], + "product_units_released": ["10", "20", None], + }, + schema={ + "product_released_to": pl.String, + "product_units_released": pl.String, + }, + ) + collector = ErrorCollector() + _count_orphan_released_units(df, "Jul24", "t.xlsx", collector) + assert len(collector) == 0 + + +def test_count_orphan_released_units_treats_whitespace_as_null(): + """Whitespace-only product_released_to (e.g. 
' ') counts as orphan.""" + df = pl.DataFrame( + { + "product_released_to": [" ", "\t", "P1"], + "product_units_released": ["10", "20", "30"], + }, + schema={ + "product_released_to": pl.String, + "product_units_released": pl.String, + }, + ) + collector = ErrorCollector() + _count_orphan_released_units(df, "Jul24", "t.xlsx", collector) + assert len(collector) == 1 + assert "2" in collector.errors[0].error_message + + +def test_count_orphan_released_units_no_collector_is_noop(): + """When no collector is passed (e.g. preview path), the helper is a no-op.""" + df = pl.DataFrame( + { + "product_released_to": [None], + "product_units_released": ["10"], + }, + schema={ + "product_released_to": pl.String, + "product_units_released": pl.String, + }, + ) + # Just must not raise. + _count_orphan_released_units(df, "Jul24", "t.xlsx", None) diff --git a/tests/test_gcp/test_select_tracker_metadata.py b/tests/test_gcp/test_select_tracker_metadata.py new file mode 100644 index 0000000..729a1d1 --- /dev/null +++ b/tests/test_gcp/test_select_tracker_metadata.py @@ -0,0 +1,77 @@ +"""Tests for select_tracker_metadata.""" + +from unittest.mock import MagicMock, patch + +import pytest +from google.api_core.exceptions import GoogleAPIError, NotFound + +from a4d.gcp.bigquery import select_tracker_metadata + + +def _query_result(rows: list[dict]) -> MagicMock: + """Build a MagicMock that mimics client.query(...).result().""" + job = MagicMock() + job.result.return_value = rows + return job + + +def test_happy_path_returns_dataframe(): + rows = [{"file_name": "T1", "clinic_code": "A", "md5": "abc", "complete": True}] + client = MagicMock() + client.query.return_value = _query_result(rows) + + result = select_tracker_metadata(client=client, dataset="tracker", project_id="proj") + + assert result is not None + assert result.height == 1 + assert result["complete"].to_list() == [True] + + +def test_not_found_returns_none(): + client = MagicMock() + client.query.side_effect = NotFound("table not found") + + result = select_tracker_metadata(client=client, dataset="tracker", project_id="proj") + + assert result is None + + +def test_schema_fallback_when_complete_column_missing(): + """Older deployments may have shipped without the `complete` column.""" + client = MagicMock() + # First query (with `complete`) raises a column-missing GoogleAPIError; + # second query (without `complete`) succeeds with the legacy schema. + fallback_rows = [{"file_name": "T1", "clinic_code": "A", "md5": "abc"}] + client.query.side_effect = [ + GoogleAPIError("Unrecognized name: complete at [1:14]"), + _query_result(fallback_rows), + ] + + result = select_tracker_metadata(client=client, dataset="tracker", project_id="proj") + + assert result is not None + assert result["complete"].to_list() == [False], ( + "Schema fallback must synthesise complete=False to force a full reprocess" + ) + assert client.query.call_count == 2 + + +def test_unrelated_google_api_error_returns_none(): + client = MagicMock() + client.query.side_effect = GoogleAPIError("Network unreachable") + + result = select_tracker_metadata(client=client, dataset="tracker", project_id="proj") + + assert result is None + + +def test_auth_failure_during_client_construction_returns_none(): + """When no client is supplied and `get_bigquery_client` raises (e.g. 
missing + credentials), the function must fall through with a warning, not raise.""" + with patch( + "a4d.gcp.bigquery.get_bigquery_client", + side_effect=Exception("DefaultCredentialsError: no creds"), + ): + result = select_tracker_metadata(dataset="tracker", project_id="proj") + + assert result is None diff --git a/tests/test_state/__init__.py b/tests/test_state/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_state/test_filter.py b/tests/test_state/test_filter.py new file mode 100644 index 0000000..e20b430 --- /dev/null +++ b/tests/test_state/test_filter.py @@ -0,0 +1,122 @@ +"""Tests for filter_unchanged_trackers.""" + +import hashlib +from pathlib import Path + +from a4d.state.filter import filter_unchanged_trackers +from a4d.state.manifest import Manifest, ManifestEntry + + +def _md5(payload: bytes) -> str: + return hashlib.md5(payload, usedforsecurity=False).hexdigest() + + +def _make_tracker(parent: Path, clinic: str, name: str, payload: bytes) -> Path: + folder = parent / clinic + folder.mkdir(parents=True, exist_ok=True) + path = folder / f"{name}.xlsx" + path.write_bytes(payload) + return path + + +def test_empty_manifest_queues_everything(tmp_path: Path): + t1 = _make_tracker(tmp_path, "A", "2024_T1", b"alpha") + t2 = _make_tracker(tmp_path, "B", "2024_T2", b"beta") + + queued, summary = filter_unchanged_trackers([t1, t2], Manifest.empty()) + + assert queued == [t1, t2] + assert summary.queued == 2 + assert summary.skipped == 0 + assert summary.new == 2 + assert summary.changed == 0 + assert summary.previously_incomplete == 0 + + +def test_all_match_skips_everything(tmp_path: Path): + t1 = _make_tracker(tmp_path, "A", "2024_T1", b"alpha") + t2 = _make_tracker(tmp_path, "B", "2024_T2", b"beta") + manifest = Manifest( + entries={ + ("A", "2024_T1"): ManifestEntry(md5=_md5(b"alpha"), complete=True), + ("B", "2024_T2"): ManifestEntry(md5=_md5(b"beta"), complete=True), + } + ) + + queued, summary = filter_unchanged_trackers([t1, t2], manifest) + + assert queued == [] + assert summary.queued == 0 + assert summary.skipped == 2 + + +def test_changed_md5_queues(tmp_path: Path): + t1 = _make_tracker(tmp_path, "A", "2024_T1", b"new content") + manifest = Manifest( + entries={("A", "2024_T1"): ManifestEntry(md5=_md5(b"old content"), complete=True)}, + ) + + queued, summary = filter_unchanged_trackers([t1], manifest) + + assert queued == [t1] + assert summary.changed == 1 + assert summary.new == 0 + + +def test_previously_incomplete_queues(tmp_path: Path): + """complete=False means the previous run didn't finish all four output stages.""" + t1 = _make_tracker(tmp_path, "A", "2024_T1", b"alpha") + manifest = Manifest( + entries={("A", "2024_T1"): ManifestEntry(md5=_md5(b"alpha"), complete=False)}, + ) + + queued, summary = filter_unchanged_trackers([t1], manifest) + + assert queued == [t1] + assert summary.previously_incomplete == 1 + assert summary.changed == 0 + assert summary.new == 0 + + +def test_mixed_classification(tmp_path: Path): + new = _make_tracker(tmp_path, "A", "2024_NEW", b"new") + changed = _make_tracker(tmp_path, "A", "2024_CHANGED", b"new bytes") + incomplete = _make_tracker(tmp_path, "A", "2024_INCOMPLETE", b"same") + unchanged = _make_tracker(tmp_path, "A", "2024_UNCHANGED", b"frozen") + + manifest = Manifest( + entries={ + ("A", "2024_CHANGED"): ManifestEntry(md5=_md5(b"old bytes"), complete=True), + ("A", "2024_INCOMPLETE"): ManifestEntry(md5=_md5(b"same"), complete=False), + ("A", "2024_UNCHANGED"): ManifestEntry(md5=_md5(b"frozen"), 
complete=True), + } + ) + + queued, summary = filter_unchanged_trackers( + [new, changed, incomplete, unchanged], manifest + ) + + assert set(queued) == {new, changed, incomplete} + assert summary.queued == 3 + assert summary.skipped == 1 + assert summary.new == 1 + assert summary.changed == 1 + assert summary.previously_incomplete == 1 + + +def test_clinic_isolation(tmp_path: Path): + """Same file stem in two clinic folders must be treated as two distinct trackers.""" + t_a = _make_tracker(tmp_path, "CLINIC_A", "shared_name", b"clinic A bytes") + t_b = _make_tracker(tmp_path, "CLINIC_B", "shared_name", b"clinic B bytes") + + manifest = Manifest( + entries={ + ("CLINIC_A", "shared_name"): ManifestEntry(md5=_md5(b"clinic A bytes"), complete=True), + ("CLINIC_B", "shared_name"): ManifestEntry(md5=_md5(b"clinic B bytes"), complete=True), + } + ) + + queued, summary = filter_unchanged_trackers([t_a, t_b], manifest) + + assert queued == [] + assert summary.skipped == 2 diff --git a/tests/test_state/test_integration.py b/tests/test_state/test_integration.py new file mode 100644 index 0000000..4994f6a --- /dev/null +++ b/tests/test_state/test_integration.py @@ -0,0 +1,101 @@ +"""Producer/consumer round-trip for the incremental-processing manifest. + +Avoids invoking the patient/product orchestrators (which need real .xlsx +structure) — instead exercises the metadata builder → manifest loader → filter +pipeline against synthetic byte-blobs that look like trackers on disk. +""" + +from pathlib import Path + +import pytest + +from a4d.state.filter import filter_unchanged_trackers +from a4d.state.source import load_previous_manifest +from a4d.tables.metadata import create_table_tracker_metadata + + +@pytest.fixture +def fake_pipeline_run(tmp_path: Path): + """Create a 2-tracker fake data_root + fully-populated output subdirs.""" + data_root = tmp_path / "data" + output_root = tmp_path / "output" + + (data_root / "CLINIC_A").mkdir(parents=True) + t1 = data_root / "CLINIC_A" / "2024_T1_Tracker.xlsx" + t1.write_bytes(b"tracker one bytes") + t2 = data_root / "CLINIC_A" / "2024_T2_Tracker.xlsx" + t2.write_bytes(b"tracker two bytes") + + # Simulate a successful run: every tracker has all four output stages. + for subdir in ( + "patient_data_cleaned", + "patient_data_raw", + "product_data_cleaned", + "product_data_raw", + ): + (output_root / subdir).mkdir(parents=True) + for stem in ("2024_T1_Tracker", "2024_T2_Tracker"): + (output_root / subdir / f"{stem}_dummy.parquet").write_bytes(b"") + + # Publish the manifest — same call the run-pipeline CLI makes at end of run. 
+ create_table_tracker_metadata(data_root, output_root) + + return {"data_root": data_root, "output_root": output_root, "t1": t1, "t2": t2} + + +def test_published_manifest_skips_all_unchanged(fake_pipeline_run): + """Round-trip 1: metadata published as complete=True ⇒ next run queues 0.""" + data_root = fake_pipeline_run["data_root"] + output_root = fake_pipeline_run["output_root"] + + discovered = sorted(data_root.rglob("*.xlsx")) + manifest = load_previous_manifest(output_root, prefer_bigquery=False) + + queued, summary = filter_unchanged_trackers(discovered, manifest) + + assert queued == [] + assert summary.skipped == 2 + assert summary.queued == 0 + + +def test_mutated_tracker_queues_exactly_one(fake_pipeline_run): + """Round-trip 2: change one tracker's bytes ⇒ filter queues exactly it.""" + data_root = fake_pipeline_run["data_root"] + output_root = fake_pipeline_run["output_root"] + t1 = fake_pipeline_run["t1"] + + # Mutate t1's bytes — t2 stays untouched. + t1.write_bytes(b"tracker one with new bytes") + + discovered = sorted(data_root.rglob("*.xlsx")) + manifest = load_previous_manifest(output_root, prefer_bigquery=False) + queued, summary = filter_unchanged_trackers(discovered, manifest) + + assert queued == [t1] + assert summary.changed == 1 + assert summary.skipped == 1 + + +def test_missing_output_stage_marks_incomplete_and_requeues(tmp_path: Path): + """If product_data_cleaned/ is missing for a tracker, complete=False ⇒ requeue.""" + data_root = tmp_path / "data" + output_root = tmp_path / "output" + + (data_root / "CLINIC_A").mkdir(parents=True) + t1 = data_root / "CLINIC_A" / "2024_T1_Tracker.xlsx" + t1.write_bytes(b"alpha") + + # Only three of four stages have output — simulates a previous run that + # crashed during the product arm. + for subdir in ("patient_data_cleaned", "patient_data_raw", "product_data_raw"): + (output_root / subdir).mkdir(parents=True) + (output_root / subdir / "2024_T1_Tracker_dummy.parquet").write_bytes(b"") + + create_table_tracker_metadata(data_root, output_root) + + manifest = load_previous_manifest(output_root, prefer_bigquery=False) + queued, summary = filter_unchanged_trackers([t1], manifest) + + assert queued == [t1] + assert summary.previously_incomplete == 1 + assert summary.changed == 0 diff --git a/tests/test_state/test_manifest.py b/tests/test_state/test_manifest.py new file mode 100644 index 0000000..57dd087 --- /dev/null +++ b/tests/test_state/test_manifest.py @@ -0,0 +1,39 @@ +"""Tests for the Manifest dataclasses.""" + +from a4d.state.manifest import Manifest, ManifestEntry + + +def test_empty_manifest(): + m = Manifest.empty() + assert len(m) == 0 + assert m.get("CLINIC_A", "tracker") is None + + +def test_lookup_by_composite_key(): + entries = { + ("CLINIC_A", "2024_T1"): ManifestEntry(md5="abc123", complete=True), + ("CLINIC_B", "2024_T1"): ManifestEntry(md5="def456", complete=False), + } + m = Manifest(entries=entries) + + assert m.get("CLINIC_A", "2024_T1") == ManifestEntry(md5="abc123", complete=True) + assert m.get("CLINIC_B", "2024_T1") == ManifestEntry(md5="def456", complete=False) + # Same file_name in a third clinic must miss — composite key isolates by clinic. 
+ assert m.get("CLINIC_C", "2024_T1") is None + assert len(m) == 2 + + +def test_manifest_entry_is_frozen(): + """Frozen dataclass: equality + immutability are part of the contract.""" + e1 = ManifestEntry(md5="abc", complete=True) + e2 = ManifestEntry(md5="abc", complete=True) + assert e1 == e2 + + import dataclasses + + try: + e1.md5 = "different" # type: ignore[misc] + except dataclasses.FrozenInstanceError: + pass + else: + raise AssertionError("ManifestEntry must be frozen") diff --git a/tests/test_state/test_source.py b/tests/test_state/test_source.py new file mode 100644 index 0000000..2731bb5 --- /dev/null +++ b/tests/test_state/test_source.py @@ -0,0 +1,87 @@ +"""Tests for load_previous_manifest source-precedence logic.""" + +from pathlib import Path +from unittest.mock import patch + +import polars as pl + +from a4d.state.manifest import Manifest +from a4d.state.source import load_previous_manifest + + +def _manifest_df() -> pl.DataFrame: + return pl.DataFrame( + { + "file_name": ["2024_T1", "2024_T2"], + "clinic_code": ["A", "B"], + "md5": ["abc", "def"], + "complete": [True, False], + } + ) + + +def test_bigquery_first_when_available(tmp_path: Path): + """BQ wins over local parquet when prefer_bigquery=True and BQ returns rows.""" + # Both sources populated; BQ should be picked. + local = tmp_path / "tables" / "tracker_metadata.parquet" + local.parent.mkdir(parents=True) + pl.DataFrame( + { + "file_name": ["LOCAL_ONLY"], + "clinic_code": ["X"], + "md5": ["local_md5"], + "complete": [True], + } + ).write_parquet(local) + + with patch("a4d.gcp.bigquery.select_tracker_metadata", return_value=_manifest_df()): + m = load_previous_manifest(tmp_path) + + assert len(m) == 2 + assert m.get("A", "2024_T1").md5 == "abc" + # Local parquet was NOT consulted. + assert m.get("X", "LOCAL_ONLY") is None + + +def test_falls_back_to_local_parquet_when_bq_returns_none(tmp_path: Path): + local = tmp_path / "tables" / "tracker_metadata.parquet" + local.parent.mkdir(parents=True) + _manifest_df().write_parquet(local) + + with patch("a4d.gcp.bigquery.select_tracker_metadata", return_value=None): + m = load_previous_manifest(tmp_path) + + assert len(m) == 2 + assert m.get("A", "2024_T1").complete is True + assert m.get("B", "2024_T2").complete is False + + +def test_returns_empty_when_neither_source_available(tmp_path: Path): + with patch("a4d.gcp.bigquery.select_tracker_metadata", return_value=None): + m = load_previous_manifest(tmp_path) + + assert m == Manifest.empty() + + +def test_skips_bigquery_when_prefer_bigquery_false(tmp_path: Path): + local = tmp_path / "tables" / "tracker_metadata.parquet" + local.parent.mkdir(parents=True) + _manifest_df().write_parquet(local) + + # patch should never be called. 
+ with patch("a4d.gcp.bigquery.select_tracker_metadata") as mock_bq: + m = load_previous_manifest(tmp_path, prefer_bigquery=False) + assert mock_bq.call_count == 0 + + assert len(m) == 2 + + +def test_handles_corrupt_local_parquet(tmp_path: Path): + local = tmp_path / "tables" / "tracker_metadata.parquet" + local.parent.mkdir(parents=True) + local.write_bytes(b"not a parquet file") + + with patch("a4d.gcp.bigquery.select_tracker_metadata", return_value=None): + m = load_previous_manifest(tmp_path) + + assert m == Manifest.empty() diff --git a/tests/test_tables/test_link_product_patient.py b/tests/test_tables/test_link_product_patient.py new file mode 100644 index 0000000..a8a4d47 --- /dev/null +++ b/tests/test_tables/test_link_product_patient.py @@ -0,0 +1,151 @@ +"""Tests for `link_product_patient` — product↔patient link validation.""" + +from collections.abc import Iterator +from pathlib import Path + +import polars as pl +import pytest +from loguru import logger + +from a4d.tables.product import link_product_patient + + +@pytest.fixture +def captured_warnings() -> Iterator[list[str]]: + """Capture loguru WARNING messages emitted during the test.""" + sink: list[str] = [] + handler_id = logger.add( + lambda msg: sink.append(str(msg)), + level="WARNING", + format="{message}", + ) + yield sink + logger.remove(handler_id) + + +def _write_patient_static(path: Path, rows: list[tuple[str, str]]) -> Path: + """Write a minimal patient_data_static.parquet with (file_name, patient_id) rows.""" + pl.DataFrame( + {"file_name": [r[0] for r in rows], "patient_id": [r[1] for r in rows]} + ).write_parquet(path) + return path + + +def test_all_match_returns_zero_no_warnings( + tmp_path: Path, captured_warnings: list[str] +) -> None: + patient_path = _write_patient_static( + tmp_path / "patient.parquet", + [ + ("tracker_a.xlsx", "KD_EW001"), + ("tracker_a.xlsx", "KD_EW002"), + ("tracker_b.xlsx", "KD_EW003"), + ], + ) + product_df = pl.DataFrame( + { + "file_name": [ + "tracker_a.xlsx", + "tracker_a.xlsx", + "tracker_a.xlsx", + "tracker_b.xlsx", + "tracker_b.xlsx", + ], + "product_released_to": [ + "KD_EW001", + "KD_EW001", + "KD_EW002", + "KD_EW003", + "KD_EW003", + ], + } + ) + + count = link_product_patient(product_df, patient_path) + + assert count == 0 + assert captured_warnings == [] + + +def test_mixed_filters_null_and_sentinel( + tmp_path: Path, captured_warnings: list[str] +) -> None: + patient_path = _write_patient_static( + tmp_path / "patient.parquet", + [("tracker_a.xlsx", "KD_EW001"), ("tracker_a.xlsx", "KD_EW002")], + ) + product_df = pl.DataFrame( + { + "file_name": [ + "tracker_a.xlsx", # match + "tracker_a.xlsx", # match + "tracker_a.xlsx", # match + "tracker_a.xlsx", # mismatch + "tracker_a.xlsx", # mismatch + "tracker_a.xlsx", # null — filtered + "tracker_a.xlsx", # "Undefined" — filtered + ], + "product_released_to": [ + "KD_EW001", + "KD_EW001", + "KD_EW002", + "KD_EW999", + "KD_EW999", + None, + "Undefined", + ], + } + ) + + count = link_product_patient(product_df, patient_path) + + assert count == 2 + # One warning per distinct (file, id) mismatch pair → 1 group → 1 warning + mismatch_warnings = [ + w for w in captured_warnings if "Unmatched product_released_to" in w + ] + assert len(mismatch_warnings) == 1 + assert "tracker_a.xlsx" in mismatch_warnings[0] + assert "KD_EW999" in mismatch_warnings[0] + assert "count=2" in mismatch_warnings[0] + + +def test_cross_file_isolation( + tmp_path: Path, captured_warnings: list[str] +) -> None: + """Same patient_id matches in one file but not 
in another.""" + patient_path = _write_patient_static( + tmp_path / "patient.parquet", + [("tracker_a.xlsx", "KD_EW001")], + ) + product_df = pl.DataFrame( + { + "file_name": ["tracker_a.xlsx", "tracker_b.xlsx"], + "product_released_to": ["KD_EW001", "KD_EW001"], + } + ) + + count = link_product_patient(product_df, patient_path) + + assert count == 1 + mismatch_warnings = [ + w for w in captured_warnings if "Unmatched product_released_to" in w + ] + assert len(mismatch_warnings) == 1 + assert "tracker_b.xlsx" in mismatch_warnings[0] + assert "KD_EW001" in mismatch_warnings[0] + + +def test_missing_patient_table_returns_zero_with_warning( + tmp_path: Path, captured_warnings: list[str] +) -> None: + product_df = pl.DataFrame( + {"file_name": ["tracker_a.xlsx"], "product_released_to": ["KD_EW001"]} + ) + missing_path = tmp_path / "does_not_exist.parquet" + + count = link_product_patient(product_df, missing_path) + + assert count == 0 + skip_warnings = [w for w in captured_warnings if "skipping" in w.lower()] + assert len(skip_warnings) == 1 diff --git a/tests/test_tables/test_metadata.py b/tests/test_tables/test_metadata.py new file mode 100644 index 0000000..bd91db2 --- /dev/null +++ b/tests/test_tables/test_metadata.py @@ -0,0 +1,150 @@ +"""Tests for tracker metadata table generation.""" + +import hashlib +from pathlib import Path + +import polars as pl +import pytest + +from a4d.tables.metadata import create_table_tracker_metadata + + +@pytest.fixture +def fake_trackers(tmp_path: Path) -> Path: + """Create a fake data_root with two clinic subfolders and three trackers.""" + data_root = tmp_path / "data" + (data_root / "CLINIC_A").mkdir(parents=True) + (data_root / "CLINIC_B").mkdir(parents=True) + + (data_root / "CLINIC_A" / "2024_A1_Tracker.xlsx").write_bytes(b"clinic A tracker 1 contents") + (data_root / "CLINIC_A" / "2024_A2_Tracker.xlsx").write_bytes(b"clinic A tracker 2 contents") + (data_root / "CLINIC_B" / "2024_B1_Tracker.xlsx").write_bytes(b"clinic B tracker contents") + return data_root + + +@pytest.fixture +def fake_output_root(tmp_path: Path) -> Path: + """Create an output_root where A1 is fully processed, A2 half-processed, B1 untouched.""" + output_root = tmp_path / "output" + for subdir in ( + "patient_data_cleaned", + "patient_data_raw", + "product_data_cleaned", + "product_data_raw", + ): + (output_root / subdir).mkdir(parents=True) + + # A1: full set of outputs across all four subdirs. + for subdir, suffix in ( + ("patient_data_cleaned", "_patient_cleaned.parquet"), + ("patient_data_raw", "_patient_raw.parquet"), + ("product_data_cleaned", "_product_cleaned.parquet"), + ("product_data_raw", "_product_raw.parquet"), + ): + (output_root / subdir / f"2024_A1_Tracker{suffix}").write_bytes(b"") + + # A2: only patient outputs (product arm skipped). + (output_root / "patient_data_raw" / "2024_A2_Tracker_patient_raw.parquet").write_bytes(b"") + (output_root / "patient_data_cleaned" / "2024_A2_Tracker_patient_cleaned.parquet").write_bytes(b"") + + # B1: no outputs. 
+ return output_root + + +def test_metadata_shape_and_schema(fake_trackers: Path, fake_output_root: Path): + out_path = create_table_tracker_metadata(fake_trackers, fake_output_root) + + assert out_path == fake_output_root / "tables" / "tracker_metadata.parquet" + assert out_path.exists() + + df = pl.read_parquet(out_path) + assert df.shape == (3, 9) + assert df.columns == [ + "file_name", + "clinic_code", + "md5", + "patient_data_cleaned", + "patient_data_raw", + "product_data_cleaned", + "product_data_raw", + "complete", + "timestamp", + ] + + +def test_metadata_presence_flags(fake_trackers: Path, fake_output_root: Path): + create_table_tracker_metadata(fake_trackers, fake_output_root) + df = pl.read_parquet(fake_output_root / "tables" / "tracker_metadata.parquet").sort("file_name") + + rows = df.to_dicts() + + a1 = next(r for r in rows if r["file_name"] == "2024_A1_Tracker") + assert a1["clinic_code"] == "CLINIC_A" + assert a1["patient_data_raw"] is True + assert a1["patient_data_cleaned"] is True + assert a1["product_data_raw"] is True + assert a1["product_data_cleaned"] is True + assert a1["complete"] is True + + a2 = next(r for r in rows if r["file_name"] == "2024_A2_Tracker") + assert a2["patient_data_raw"] is True + assert a2["product_data_raw"] is False + assert a2["complete"] is False + + b1 = next(r for r in rows if r["file_name"] == "2024_B1_Tracker") + assert b1["clinic_code"] == "CLINIC_B" + assert all(b1[s] is False for s in ( + "patient_data_raw", + "patient_data_cleaned", + "product_data_raw", + "product_data_cleaned", + )) + assert b1["complete"] is False + + +def test_metadata_md5_matches_file_bytes(fake_trackers: Path, fake_output_root: Path): + create_table_tracker_metadata(fake_trackers, fake_output_root) + df = pl.read_parquet(fake_output_root / "tables" / "tracker_metadata.parquet") + + for row in df.iter_rows(named=True): + tracker_path = fake_trackers / row["clinic_code"] / f"{row['file_name']}.xlsx" + expected = hashlib.md5(tracker_path.read_bytes(), usedforsecurity=False).hexdigest() + assert row["md5"] == expected + + +def test_metadata_handles_missing_output_subdirs(fake_trackers: Path, tmp_path: Path): + """Empty output_root (no subdirs yet) should produce all-False presence flags.""" + output_root = tmp_path / "empty_output" + output_root.mkdir() + + create_table_tracker_metadata(fake_trackers, output_root) + df = pl.read_parquet(output_root / "tables" / "tracker_metadata.parquet") + + assert df.height == 3 + for col in ("patient_data_raw", "patient_data_cleaned", "product_data_raw", "product_data_cleaned"): + assert df[col].to_list() == [False, False, False] + assert df["complete"].to_list() == [False, False, False] + + +def test_metadata_empty_data_root(tmp_path: Path): + """Empty data_root should produce a zero-row parquet with correct schema.""" + data_root = tmp_path / "data" + data_root.mkdir() + output_root = tmp_path / "output" + output_root.mkdir() + + out_path = create_table_tracker_metadata(data_root, output_root) + df = pl.read_parquet(out_path) + + assert df.height == 0 + assert df.columns == [ + "file_name", + "clinic_code", + "md5", + "patient_data_cleaned", + "patient_data_raw", + "product_data_cleaned", + "product_data_raw", + "complete", + "timestamp", + ] diff --git a/tests/test_tables/test_patient.py b/tests/test_tables/test_patient.py index 31aa932..26de6b2 100644 --- a/tests/test_tables/test_patient.py +++ b/tests/test_tables/test_patient.py @@ -209,7 +209,8 @@ def test_create_table_patient_data_static(cleaned_patient_data_files: 
list[Path] """Test creation of static patient data table.""" output_dir = tmp_path / "output" - output_file = create_table_patient_data_static(cleaned_patient_data_files, output_dir) + patient_data = read_cleaned_patient_data(cleaned_patient_data_files) + output_file = create_table_patient_data_static(patient_data, output_dir) assert output_file.exists() assert output_file.name == "patient_data_static.parquet" @@ -242,7 +243,8 @@ def test_create_table_patient_data_monthly(cleaned_patient_data_files: list[Path """Test creation of monthly patient data table.""" output_dir = tmp_path / "output" - output_file = create_table_patient_data_monthly(cleaned_patient_data_files, output_dir) + patient_data = read_cleaned_patient_data(cleaned_patient_data_files) + output_file = create_table_patient_data_monthly(patient_data, output_dir) assert output_file.exists() assert output_file.name == "patient_data_monthly.parquet" @@ -266,7 +268,8 @@ def test_create_table_patient_data_annual(cleaned_patient_data_files: list[Path] """Test creation of annual patient data table.""" output_dir = tmp_path / "output" - output_file = create_table_patient_data_annual(cleaned_patient_data_files, output_dir) + patient_data = read_cleaned_patient_data(cleaned_patient_data_files) + output_file = create_table_patient_data_annual(patient_data, output_dir) assert output_file.exists() assert output_file.name == "patient_data_annual.parquet" @@ -334,7 +337,8 @@ def test_create_table_patient_data_annual_filters_pre_2024(tmp_path: Path): df1.write_parquet(file1) output_dir = tmp_path / "output" - output_file = create_table_patient_data_annual([file1], output_dir) + patient_data = read_cleaned_patient_data([file1]) + output_file = create_table_patient_data_annual(patient_data, output_dir) result = pl.read_parquet(output_file) assert result.shape[0] == 0 @@ -343,7 +347,8 @@ def test_create_table_patient_data_annual_filters_pre_2024(tmp_path: Path): def test_static_table_sorting(cleaned_patient_data_files: list[Path], tmp_path: Path): """Test that static table is sorted correctly.""" output_dir = tmp_path / "output" - output_file = create_table_patient_data_static(cleaned_patient_data_files, output_dir) + patient_data = read_cleaned_patient_data(cleaned_patient_data_files) + output_file = create_table_patient_data_static(patient_data, output_dir) result = pl.read_parquet(output_file) diff --git a/tests/test_validate/__init__.py b/tests/test_validate/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_validate/test_patient_source_vs_output.py b/tests/test_validate/test_patient_source_vs_output.py new file mode 100644 index 0000000..850b1fb --- /dev/null +++ b/tests/test_validate/test_patient_source_vs_output.py @@ -0,0 +1,311 @@ +"""Tests for source_vs_output_patient. + +Fixtures pass ``schema={...}`` explicitly so all-None columns keep ``pl.Utf8`` +dtype (per the polars all-None ``pl.Null`` inference trap). 
+""" + +from __future__ import annotations + +import polars as pl + +from a4d.errors import ErrorCollector +from a4d.validate.common import normalize_patient_id +from a4d.validate.source_vs_output_patient import ( + _join_for_cell_checks, + check_missing_patients, + check_out_of_range, + check_unexpected_nulls, + check_value_shifts, +) + +RAW_SCHEMA = { + "clinic_id": pl.Utf8, + "patient_id": pl.Utf8, + "tracker_year": pl.Int64, + "tracker_month": pl.Int64, + "file_name": pl.Utf8, + "weight": pl.Utf8, + "height": pl.Utf8, + "hba1c_updated": pl.Utf8, +} + +CLEAN_SCHEMA = { + "clinic_id": pl.Utf8, + "patient_id": pl.Utf8, + "tracker_year": pl.Int64, + "tracker_month": pl.Int64, + "file_name": pl.Utf8, + "weight": pl.Float64, + "height": pl.Float64, + "hba1c_updated": pl.Float64, +} + + +def _make_raw(rows: list[dict]) -> pl.DataFrame: + return pl.DataFrame(rows, schema=RAW_SCHEMA) + + +def _make_clean(rows: list[dict]) -> pl.DataFrame: + return pl.DataFrame(rows, schema=CLEAN_SCHEMA) + + +def test_normalize_patient_id_strips_transfer_suffix() -> None: + df = pl.DataFrame( + {"patient_id": ["MY_SM003_SB", "LA-MH093_LF", "TH_NK001", "SOLO"]}, + schema={"patient_id": pl.Utf8}, + ) + out = df.with_columns(normalize_patient_id(pl.col("patient_id")).alias("n")) + assert out["n"].to_list() == ["MY_SM003", "LA_MH093", "TH_NK001", "SOLO"] + + +def test_missing_patient_does_not_fire_for_transferred_id() -> None: + raw = _make_raw( + [ + { + "clinic_id": "CDA", + "patient_id": "MY_SM003_SB", + "tracker_year": 2024, + "tracker_month": 1, + "file_name": "f", + "weight": "70", + "height": "1.7", + "hba1c_updated": "7.0", + } + ] + ) + cleaned = _make_clean( + [ + { + "clinic_id": "CDA", + "patient_id": "MY_SM003", # cleaner stripped suffix + "tracker_year": 2024, + "tracker_month": 1, + "file_name": "f", + "weight": 70.0, + "height": 1.7, + "hba1c_updated": 7.0, + } + ] + ) + coll = ErrorCollector() + check_missing_patients(raw, cleaned, coll) + assert len(coll) == 0 + + +def test_missing_patient_fires_when_truly_absent() -> None: + raw = _make_raw( + [ + { + "clinic_id": "CDA", + "patient_id": "MY_SM003", + "tracker_year": 2024, + "tracker_month": 1, + "file_name": "f", + "weight": "70", + "height": "1.7", + "hba1c_updated": "7.0", + }, + { + "clinic_id": "CDA", + "patient_id": "MY_SM999", # absent in cleaned + "tracker_year": 2024, + "tracker_month": 1, + "file_name": "f", + "weight": "60", + "height": "1.6", + "hba1c_updated": "8.0", + }, + ] + ) + cleaned = _make_clean( + [ + { + "clinic_id": "CDA", + "patient_id": "MY_SM003", + "tracker_year": 2024, + "tracker_month": 1, + "file_name": "f", + "weight": 70.0, + "height": 1.7, + "hba1c_updated": 7.0, + } + ] + ) + coll = ErrorCollector() + check_missing_patients(raw, cleaned, coll) + msgs = [e.error_message for e in coll.errors] + assert any("MY_SM999" in m and "MISSING_ROW" in m for m in msgs) + assert not any("PHANTOM_ROW" in m for m in msgs) + + +def test_value_shift_skips_height_unit_conversion() -> None: + raw = _make_raw( + [ + { + "clinic_id": "CDA", + "patient_id": "MY_SM003", + "tracker_year": 2024, + "tracker_month": 1, + "file_name": "f", + "weight": "70", + "height": "170", # cm, cleaner divides by 100 + "hba1c_updated": "7.0", + } + ] + ) + cleaned = _make_clean( + [ + { + "clinic_id": "CDA", + "patient_id": "MY_SM003", + "tracker_year": 2024, + "tracker_month": 1, + "file_name": "f", + "weight": 70.0, + "height": 1.7, # converted from 170 cm + "hba1c_updated": 7.0, + } + ] + ) + coll = ErrorCollector() + joined = 
_join_for_cell_checks(raw, cleaned) + check_value_shifts(joined, coll) + # Should be zero shifts: weight matches, hba1c matches, height is skipped. + shift_msgs = [e.error_message for e in coll.errors if "VALUE_SHIFT" in e.error_message] + assert shift_msgs == [] + + +def test_value_shift_fires_on_genuine_mismatch() -> None: + raw = _make_raw( + [ + { + "clinic_id": "CDA", + "patient_id": "MY_SM003", + "tracker_year": 2024, + "tracker_month": 1, + "file_name": "f", + "weight": "70", + "height": "1.7", + "hba1c_updated": "7.0", + } + ] + ) + cleaned = _make_clean( + [ + { + "clinic_id": "CDA", + "patient_id": "MY_SM003", + "tracker_year": 2024, + "tracker_month": 1, + "file_name": "f", + "weight": 99.0, # mismatch + "height": 1.7, + "hba1c_updated": 7.0, + } + ] + ) + coll = ErrorCollector() + joined = _join_for_cell_checks(raw, cleaned) + check_value_shifts(joined, coll) + assert any("VALUE_SHIFT" in e.error_message and e.column == "weight" for e in coll.errors) + + +def test_unexpected_null_fires_when_raw_has_value_cleaned_does_not() -> None: + raw = _make_raw( + [ + { + "clinic_id": "CDA", + "patient_id": "MY_SM003", + "tracker_year": 2024, + "tracker_month": 1, + "file_name": "f", + "weight": "70", + "height": "1.7", + "hba1c_updated": "7.0", + } + ] + ) + cleaned = _make_clean( + [ + { + "clinic_id": "CDA", + "patient_id": "MY_SM003", + "tracker_year": 2024, + "tracker_month": 1, + "file_name": "f", + "weight": 70.0, + "height": 1.7, + "hba1c_updated": None, # silently dropped despite parseable raw + } + ] + ) + coll = ErrorCollector() + joined = _join_for_cell_checks(raw, cleaned) + check_unexpected_nulls(joined, coll) + msgs = [e.error_message for e in coll.errors if e.column == "hba1c_updated"] + assert msgs and "UNEXPECTED_NULL" in msgs[0] and "was_parseable=True" in msgs[0] + + +def test_out_of_range_height_skips_cm_value() -> None: + """170 cm raw should NOT fire after the cm->m auto-conversion.""" + raw = _make_raw( + [ + { + "clinic_id": "CDA", + "patient_id": "MY_SM003", + "tracker_year": 2024, + "tracker_month": 1, + "file_name": "f", + "weight": "70", + "height": "170", + "hba1c_updated": "7.0", + } + ] + ) + coll = ErrorCollector() + check_out_of_range(raw, coll) + height_findings = [e for e in coll.errors if e.column == "height"] + assert height_findings == [] + + +def test_out_of_range_height_fires_after_auto_conversion() -> None: + """height=250 (cm) -> auto-converts to 2.5 m -> still exceeds max 2.3.""" + raw = _make_raw( + [ + { + "clinic_id": "CDA", + "patient_id": "MY_SM003", + "tracker_year": 2024, + "tracker_month": 1, + "file_name": "f", + "weight": "70", + "height": "250", + "hba1c_updated": "7.0", + } + ] + ) + coll = ErrorCollector() + check_out_of_range(raw, coll) + assert any( + e.column == "height" and "OUT_OF_RANGE_RAW" in e.error_message for e in coll.errors + ) + + +def test_out_of_range_weight_fires_for_obvious_outlier() -> None: + raw = _make_raw( + [ + { + "clinic_id": "CDA", + "patient_id": "MY_SM003", + "tracker_year": 2024, + "tracker_month": 1, + "file_name": "f", + "weight": "350", # exceeds 200 max + "height": "1.7", + "hba1c_updated": "7.0", + } + ] + ) + coll = ErrorCollector() + check_out_of_range(raw, coll) + assert any(e.column == "weight" for e in coll.errors) diff --git a/tests/test_validate/test_product_source_vs_output.py b/tests/test_validate/test_product_source_vs_output.py new file mode 100644 index 0000000..f29f293 --- /dev/null +++ b/tests/test_validate/test_product_source_vs_output.py @@ -0,0 +1,248 @@ +"""Tests for 
source_vs_output_product.""" + +from __future__ import annotations + +import polars as pl + +from a4d.errors import ErrorCollector +from a4d.validate.source_vs_output_product import ( + _explode_multi_product_cells, + check_column_null_rate_delta, + check_missing_groups, + check_row_count_delta, +) + +RAW_SCHEMA = { + "file_name": pl.Utf8, + "product_sheet_name": pl.Utf8, + "product": pl.Utf8, + "product_units_received": pl.Utf8, + "product_units_released": pl.Utf8, +} + +CLEAN_SCHEMA = { + "file_name": pl.Utf8, + "product_sheet_name": pl.Utf8, + "product": pl.Utf8, + "product_units_received": pl.Float64, + "product_units_released": pl.Float64, + "product_balance_status": pl.Utf8, +} + + +def _raw(rows: list[dict]) -> pl.DataFrame: + return pl.DataFrame(rows, schema=RAW_SCHEMA) + + +def _clean(rows: list[dict]) -> pl.DataFrame: + return pl.DataFrame(rows, schema=CLEAN_SCHEMA) + + +def test_explode_splits_multi_product_cell() -> None: + df = _raw( + [ + { + "file_name": "f", + "product_sheet_name": "Jan24", + "product": "Insulin A; Insulin B", + "product_units_received": "10", + "product_units_released": None, + } + ] + ) + exploded = _explode_multi_product_cells(df) + assert sorted(exploded["product"].to_list()) == ["Insulin A", "Insulin B"] + + +def test_missing_group_after_explode() -> None: + raw = _raw( + [ + { + "file_name": "f", + "product_sheet_name": "Jan24", + "product": "Insulin A; Insulin B", + "product_units_received": "10", + "product_units_released": None, + } + ] + ) + cleaned = _clean( + [ + { + "file_name": "f", + "product_sheet_name": "Jan24", + "product": "Insulin A", + "product_units_received": 10.0, + "product_units_released": None, + "product_balance_status": "start", + } + # Insulin B is missing from cleaned -> should fire MISSING_GROUP + ] + ) + coll = ErrorCollector() + check_missing_groups(_explode_multi_product_cells(raw), cleaned, coll) + msgs = [e.error_message for e in coll.errors] + assert any("MISSING_GROUP" in m and "Insulin B" in m for m in msgs) + + +def test_phantom_group_fires_for_invented_product() -> None: + raw = _raw( + [ + { + "file_name": "f", + "product_sheet_name": "Jan24", + "product": "Insulin A", + "product_units_received": "10", + "product_units_released": None, + } + ] + ) + cleaned = _clean( + [ + { + "file_name": "f", + "product_sheet_name": "Jan24", + "product": "Insulin A", + "product_units_received": 10.0, + "product_units_released": None, + "product_balance_status": "start", + }, + { + "file_name": "f", + "product_sheet_name": "Jan24", + "product": "Phantom Drug", + "product_units_received": 5.0, + "product_units_released": None, + "product_balance_status": "start", + }, + ] + ) + coll = ErrorCollector() + check_missing_groups(_explode_multi_product_cells(raw), cleaned, coll) + assert any( + "PHANTOM_GROUP" in e.error_message and "Phantom Drug" in e.error_message + for e in coll.errors + ) + + +def test_row_count_delta_fires_when_counts_differ() -> None: + raw = _raw( + [ + { + "file_name": "f", + "product_sheet_name": "Jan24", + "product": "Insulin A", + "product_units_received": "10", + "product_units_released": None, + }, + { + "file_name": "f", + "product_sheet_name": "Jan24", + "product": "Insulin A", + "product_units_received": "20", + "product_units_released": None, + }, + { + "file_name": "f", + "product_sheet_name": "Jan24", + "product": "Insulin A", + "product_units_received": "30", + "product_units_released": None, + }, + ] + ) + cleaned = _clean( + [ + { + "file_name": "f", + "product_sheet_name": "Jan24", + "product": 
"Insulin A", + "product_units_received": 10.0, + "product_units_released": None, + "product_balance_status": "start", + } + # Cleaner dropped two rows + ] + ) + coll = ErrorCollector() + check_row_count_delta(_explode_multi_product_cells(raw), cleaned, coll) + assert any( + "ROW_COUNT_DELTA" in e.error_message and "delta=-2" in e.error_message + for e in coll.errors + ) + + +def test_row_count_delta_silent_when_counts_match() -> None: + raw = _raw( + [ + { + "file_name": "f", + "product_sheet_name": "Jan24", + "product": "Insulin A", + "product_units_received": "10", + "product_units_released": None, + } + ] + ) + cleaned = _clean( + [ + { + "file_name": "f", + "product_sheet_name": "Jan24", + "product": "Insulin A", + "product_units_received": 10.0, + "product_units_released": None, + "product_balance_status": "start", + } + ] + ) + coll = ErrorCollector() + check_row_count_delta(_explode_multi_product_cells(raw), cleaned, coll) + assert len(coll) == 0 + + +def test_column_null_rate_delta_fires_for_silent_loss() -> None: + raw = _raw( + [ + { + "file_name": "f", + "product_sheet_name": "Jan24", + "product": "A", + "product_units_received": "10", + "product_units_released": "5", + }, + { + "file_name": "f", + "product_sheet_name": "Jan24", + "product": "B", + "product_units_received": "20", + "product_units_released": "5", + }, + ] + ) + cleaned = _clean( + [ + { + "file_name": "f", + "product_sheet_name": "Jan24", + "product": "A", + "product_units_received": 10.0, + "product_units_released": None, # cleaner dropped both + "product_balance_status": "start", + }, + { + "file_name": "f", + "product_sheet_name": "Jan24", + "product": "B", + "product_units_received": 20.0, + "product_units_released": None, + "product_balance_status": "end", + }, + ] + ) + coll = ErrorCollector() + check_column_null_rate_delta(raw, cleaned, coll) + assert any( + "COLUMN_NULL_RATE_DELTA" in e.error_message and e.column == "product_units_released" + for e in coll.errors + )