Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions nuh_helper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@
- **Dataset profiling**: profile a dataset into a Scan Report
"""

from nuh_helper.date_shift import (
import logging

logging.getLogger(__name__).addHandler(logging.NullHandler())

from nuh_helper.date_shift import ( # noqa: E402
apply_date_shifts,
generate_shift_mappings,
load_shift_mappings,
shift_excel_dates,
)
from nuh_helper.profile import generate_scan_report
from nuh_helper.profile import generate_scan_report # noqa: E402

__all__ = [
"shift_excel_dates",
Expand Down
42 changes: 38 additions & 4 deletions nuh_helper/date_shift/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
in an Excel file, with support for reproducible shifts using a linking table.
"""

import logging
import random
from datetime import date, datetime
from pathlib import Path
from typing import Any, cast

import pandas as pd

logger = logging.getLogger(__name__)


def generate_shift_mappings(
patient_ids: list[str],
Expand Down Expand Up @@ -52,6 +55,7 @@ def load_shift_mappings(csv_path: str) -> pd.DataFrame:
df = pd.read_csv(csv_path)
if "patient_id" not in df.columns or "shift_days" not in df.columns:
raise ValueError("CSV must contain 'patient_id' and 'shift_days' columns")
logger.info("Loaded %d shift mapping(s) from '%s'", len(df), csv_path)
return df


Expand Down Expand Up @@ -139,10 +143,21 @@ def apply_date_shifts(

for date_col in date_columns:
if date_col not in df.columns:
logger.warning(
"Date column '%s' not found in DataFrame, skipping", date_col
)
continue

# Parse flexible date strings (handles YYYY-DD-MM and placeholders "Unknown")
non_null_before = df[date_col].notna().sum()
df[date_col] = df[date_col].apply(_parse_date_value)
parse_failures = non_null_before - sum(x is not None for x in df[date_col])
if parse_failures > 0:
logger.debug(
"Column '%s': %d value(s) could not be parsed as dates",
date_col,
parse_failures,
)

# Apply shifts
df[date_col] = df.apply(
Expand Down Expand Up @@ -206,6 +221,10 @@ def shift_excel_dates(
If None, Excel's default date format is used.
Common formats: 'YYYY-MM-DD', 'MM/DD/YYYY', 'DD-MM-YYYY', etc.
""" # noqa: E501
logger.info("Shifting dates: '%s' → '%s'", input_file, output_file)
logger.debug(
"Shift range: %d to %d days, seed=%s", min_shift_days, max_shift_days, seed
)

def _read_sheet_with_structure(
excel_file: pd.ExcelFile,
Expand Down Expand Up @@ -341,21 +360,30 @@ def _write_sheet_with_structure(
.unique()
.tolist()
)
logger.info(
"Found %d patient(s) in sheet '%s'", len(patient_ids), patient_sheet
)

# Generate or load shift mappings
if linking_table_path and Path(linking_table_path).exists():
logger.info("Loading shift mappings from '%s'", linking_table_path)
shift_mappings = load_shift_mappings(linking_table_path)
# Filter to only include patient IDs that exist in the data
shift_mappings = shift_mappings[shift_mappings["patient_id"].isin(patient_ids)]
# Add any missing patient IDs with random shifts
existing_ids = set(shift_mappings["patient_id"])
missing_ids = [pid for pid in patient_ids if pid not in existing_ids]
if missing_ids:
logger.warning(
"%d patient(s) had no entry in the linking table; new shifts generated",
len(missing_ids),
)
new_shifts = generate_shift_mappings(
missing_ids, min_shift_days, max_shift_days, seed
)
shift_mappings = pd.concat([shift_mappings, new_shifts], ignore_index=True)
else:
logger.info("Generating shift mappings for %d patient(s)", len(patient_ids))
shift_mappings = generate_shift_mappings(
patient_ids, min_shift_days, max_shift_days, seed
)
Expand Down Expand Up @@ -388,6 +416,11 @@ def _write_sheet_with_structure(
date_columns: list[str] = cast(list[str], config["date_columns"])
header_row = cast(int, config.get("header_row", header_row))
sheet_date_columns = date_columns
logger.info(
"Shifting %d date column(s) in sheet '%s'",
len(date_columns),
sheet_name,
)

# Read sheet preserving structure
df, description_df, description_rows = _read_sheet_with_structure(
Expand Down Expand Up @@ -421,11 +454,12 @@ def _write_sheet_with_structure(
date_format=date_format,
)

logger.info("Output written to '%s'", output_file)

# Save linking table
if linking_table_output:
shift_mappings.to_csv(linking_table_output, index=False)
else:
shift_mappings.to_csv("shift_mappings.csv", index=False)
linking_path = linking_table_output or "shift_mappings.csv"
shift_mappings.to_csv(linking_path, index=False)
logger.info("Linking table saved to '%s'", linking_path)


__all__ = [
Expand Down
7 changes: 7 additions & 0 deletions nuh_helper/profile/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import csv
import logging
from collections import Counter, defaultdict
from datetime import datetime
from pathlib import Path

from openpyxl import Workbook

logger = logging.getLogger(__name__)

SCAN_REPORT_FILE_NAME = "ScanReport.xlsx"

FIELD_OVERVIEW_HEADERS = [
Expand Down Expand Up @@ -71,11 +74,14 @@ def generate_scan_report(
output_path: str = SCAN_REPORT_FILE_NAME,
min_cell_count: int = 1,
) -> str:
logger.info("Generating scan report for %d table(s)", len(csv_files))

tables = []

for csv_file in csv_files:
csv_file = Path(csv_file)
header = read_csv_header(csv_file.as_posix())
logger.info("Scanning '%s' (%d field(s))", csv_file.name, len(header))
tables.append(
{"name": csv_file.name, "path": csv_file.as_posix(), "fields": header}
)
Expand Down Expand Up @@ -151,4 +157,5 @@ def generate_scan_report(
meta_sheet.append(["minCellCount", min_cell_count])

wb.save(output_path)
logger.info("Scan report written to '%s'", output_path)
return output_path
Loading