Skip to content

Commit 7cc0764

Browse files
committed
format
1 parent 59507bb commit 7cc0764

29 files changed

Lines changed: 873 additions & 632 deletions

a4d-python/scripts/check_sheets.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ def check_sheets():
99
"""Compare which sheets were processed."""
1010

1111
r_file = Path("output/patient_data_raw/R/2024_Sibu Hospital A4D Tracker_patient_raw.parquet")
12-
python_file = Path("output/patient_data_raw/Python/2024_Sibu Hospital A4D Tracker_patient_raw.parquet")
12+
python_file = Path(
13+
"output/patient_data_raw/Python/2024_Sibu Hospital A4D Tracker_patient_raw.parquet"
14+
)
1315

1416
df_r = pl.read_parquet(r_file)
1517
df_python = pl.read_parquet(python_file)

a4d-python/scripts/compare_r_vs_python.py

Lines changed: 95 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222

2323
# Fixed base directories for R and Python outputs
2424
R_OUTPUT_BASE = Path("/Volumes/USB SanDisk 3.2Gen1 Media/a4d/output_r/patient_data_cleaned")
25-
PYTHON_OUTPUT_BASE = Path("/Volumes/USB SanDisk 3.2Gen1 Media/a4d/output_python/patient_data_cleaned")
25+
PYTHON_OUTPUT_BASE = Path(
26+
"/Volumes/USB SanDisk 3.2Gen1 Media/a4d/output_python/patient_data_cleaned"
27+
)
2628

2729

2830
def display_basic_stats(r_df: pl.DataFrame, py_df: pl.DataFrame, file_name: str):
@@ -46,7 +48,7 @@ def display_basic_stats(r_df: pl.DataFrame, py_df: pl.DataFrame, file_name: str)
4648
"Records",
4749
f"{r_count:,}",
4850
f"{py_count:,}",
49-
f"[{diff_style}]{diff_count:+,} ({diff_pct:+.1f}%)[/{diff_style}]"
51+
f"[{diff_style}]{diff_count:+,} ({diff_pct:+.1f}%)[/{diff_style}]",
5052
)
5153

5254
# Column counts
@@ -56,10 +58,7 @@ def display_basic_stats(r_df: pl.DataFrame, py_df: pl.DataFrame, file_name: str)
5658
col_style = "green" if col_diff == 0 else "yellow"
5759

5860
stats_table.add_row(
59-
"Columns",
60-
f"{r_cols:,}",
61-
f"{py_cols:,}",
62-
f"[{col_style}]{col_diff:+,}[/{col_style}]"
61+
"Columns", f"{r_cols:,}", f"{py_cols:,}", f"[{col_style}]{col_diff:+,}[/{col_style}]"
6362
)
6463

6564
console.print(stats_table)
@@ -144,8 +143,12 @@ def compare_metadata_fields(r_df: pl.DataFrame, py_df: pl.DataFrame):
144143

145144
# Key metadata fields that must be identical
146145
metadata_fields = [
147-
"tracker_year", "tracker_month", "tracker_date",
148-
"file_name", "sheet_name", "patient_id"
146+
"tracker_year",
147+
"tracker_month",
148+
"tracker_date",
149+
"file_name",
150+
"sheet_name",
151+
"patient_id",
149152
]
150153

151154
existing_fields = [f for f in metadata_fields if f in r_df.columns and f in py_df.columns]
@@ -211,8 +214,15 @@ def compare_patient_records(r_df: pl.DataFrame, py_df: pl.DataFrame, n_samples:
211214
py_record = py_records.head(1).to_dicts()[0]
212215

213216
comparison_fields = [
214-
"tracker_year", "tracker_month", "tracker_date", "sheet_name",
215-
"sex", "age", "dob", "status", "province"
217+
"tracker_year",
218+
"tracker_month",
219+
"tracker_date",
220+
"sheet_name",
221+
"sex",
222+
"age",
223+
"dob",
224+
"status",
225+
"province",
216226
]
217227

218228
comp_table = Table(box=box.SIMPLE, show_header=False)
@@ -232,7 +242,7 @@ def compare_patient_records(r_df: pl.DataFrame, py_df: pl.DataFrame, n_samples:
232242
field,
233243
str(r_val)[:25],
234244
str(py_val)[:25],
235-
f"[{match_style}]{match}[/{match_style}]"
245+
f"[{match_style}]{match}[/{match_style}]",
236246
)
237247

238248
console.print(comp_table)
@@ -257,7 +267,9 @@ def find_value_mismatches(r_df: pl.DataFrame, py_df: pl.DataFrame):
257267

258268
try:
259269
joined = r_df.join(py_df, on=join_keys, how="inner", suffix="_py")
260-
console.print(f"[cyan]Analyzing {len(joined):,} common records (matched on {'+'.join(join_keys)})[/cyan]\n")
270+
console.print(
271+
f"[cyan]Analyzing {len(joined):,} common records (matched on {'+'.join(join_keys)})[/cyan]\n"
272+
)
261273
except Exception as e:
262274
console.print(f"[red]Error joining datasets: {e}[/red]\n")
263275
return
@@ -278,31 +290,49 @@ def find_value_mismatches(r_df: pl.DataFrame, py_df: pl.DataFrame):
278290
try:
279291
# Check if column is numeric (float or int)
280292
col_dtype = joined[col].dtype
281-
is_numeric = col_dtype in [pl.Float32, pl.Float64, pl.Int8, pl.Int16, pl.Int32, pl.Int64, pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64]
293+
is_numeric = col_dtype in [
294+
pl.Float32,
295+
pl.Float64,
296+
pl.Int8,
297+
pl.Int16,
298+
pl.Int32,
299+
pl.Int64,
300+
pl.UInt8,
301+
pl.UInt16,
302+
pl.UInt32,
303+
pl.UInt64,
304+
]
282305

283306
if is_numeric:
284307
# For numeric columns, use approximate comparison
285308
# Two values are considered equal if |a - b| <= max(rel_tol * max(|a|, |b|), abs_tol)
286309

287310
# Add columns for comparison logic
288-
comparison_df = joined.with_columns([
289-
# Calculate absolute difference
290-
((pl.col(col) - pl.col(col_py)).abs()).alias("_abs_diff"),
291-
# Calculate tolerance threshold
292-
pl.max_horizontal([
293-
FLOAT_REL_TOL * pl.max_horizontal([pl.col(col).abs(), pl.col(col_py).abs()]),
294-
pl.lit(FLOAT_ABS_TOL)
295-
]).alias("_tolerance"),
296-
# Check null status
297-
pl.col(col).is_null().alias("_col_null"),
298-
pl.col(col_py).is_null().alias("_col_py_null"),
299-
])
311+
comparison_df = joined.with_columns(
312+
[
313+
# Calculate absolute difference
314+
((pl.col(col) - pl.col(col_py)).abs()).alias("_abs_diff"),
315+
# Calculate tolerance threshold
316+
pl.max_horizontal(
317+
[
318+
FLOAT_REL_TOL
319+
* pl.max_horizontal([pl.col(col).abs(), pl.col(col_py).abs()]),
320+
pl.lit(FLOAT_ABS_TOL),
321+
]
322+
).alias("_tolerance"),
323+
# Check null status
324+
pl.col(col).is_null().alias("_col_null"),
325+
pl.col(col_py).is_null().alias("_col_py_null"),
326+
]
327+
)
300328

301329
# Find mismatches
302330
# Mismatch if: (1) null status differs OR (2) both not null and differ by more than tolerance
303331
mismatched_rows = comparison_df.filter(
304-
(pl.col("_col_null") != pl.col("_col_py_null")) | # Null mismatch
305-
((~pl.col("_col_null")) & (pl.col("_abs_diff") > pl.col("_tolerance"))) # Value mismatch
332+
(pl.col("_col_null") != pl.col("_col_py_null")) # Null mismatch
333+
| (
334+
(~pl.col("_col_null")) & (pl.col("_abs_diff") > pl.col("_tolerance"))
335+
) # Value mismatch
306336
)
307337
else:
308338
# For non-numeric columns, use exact comparison
@@ -313,12 +343,14 @@ def find_value_mismatches(r_df: pl.DataFrame, py_df: pl.DataFrame):
313343
if mismatch_count > 0:
314344
mismatch_pct = (mismatch_count / len(joined)) * 100
315345
# Include patient_id and sheet_name in examples for debugging
316-
examples_with_ids = mismatched_rows.select(["patient_id", "sheet_name", col, col_py])
346+
examples_with_ids = mismatched_rows.select(
347+
["patient_id", "sheet_name", col, col_py]
348+
)
317349
mismatches[col] = {
318350
"count": mismatch_count,
319351
"percentage": mismatch_pct,
320352
"examples": mismatched_rows.select([col, col_py]).head(3),
321-
"examples_with_ids": examples_with_ids
353+
"examples_with_ids": examples_with_ids,
322354
}
323355
except Exception as e:
324356
# Some columns might not support comparison
@@ -332,28 +364,38 @@ def find_value_mismatches(r_df: pl.DataFrame, py_df: pl.DataFrame):
332364
mismatch_table.add_column("%", justify="right")
333365
mismatch_table.add_column("Priority", justify="center")
334366

335-
for col, stats in sorted(mismatches.items(), key=lambda x: x[1]["percentage"], reverse=True):
367+
for col, stats in sorted(
368+
mismatches.items(), key=lambda x: x[1]["percentage"], reverse=True
369+
):
336370
# Determine priority
337-
if col in ["patient_id", "tracker_year", "tracker_month", "tracker_date", "file_name", "sheet_name"]:
371+
if col in [
372+
"patient_id",
373+
"tracker_year",
374+
"tracker_month",
375+
"tracker_date",
376+
"file_name",
377+
"sheet_name",
378+
]:
338379
priority = "[red]HIGH[/red]"
339380
elif stats["percentage"] > 10:
340381
priority = "[yellow]MEDIUM[/yellow]"
341382
else:
342383
priority = "[dim]LOW[/dim]"
343384

344385
mismatch_table.add_row(
345-
col,
346-
f"{stats['count']:,}",
347-
f"{stats['percentage']:.1f}%",
348-
priority
386+
col, f"{stats['count']:,}", f"{stats['percentage']:.1f}%", priority
349387
)
350388

351389
console.print(mismatch_table)
352390

353391
# Show ALL mismatched columns with patient_id and sheet_name
354392
console.print("\n[bold]Detailed Mismatches (showing ALL errors):[/bold]")
355-
for col, stats in sorted(mismatches.items(), key=lambda x: x[1]["percentage"], reverse=True):
356-
console.print(f"\n[bold cyan]{col}:[/bold cyan] {stats['count']} mismatches ({stats['percentage']:.1f}%)")
393+
for col, stats in sorted(
394+
mismatches.items(), key=lambda x: x[1]["percentage"], reverse=True
395+
):
396+
console.print(
397+
f"\n[bold cyan]{col}:[/bold cyan] {stats['count']} mismatches ({stats['percentage']:.1f}%)"
398+
)
357399
# Include patient_id and sheet_name in examples
358400
examples_with_ids = stats["examples_with_ids"]
359401
console.print(examples_with_ids)
@@ -383,12 +425,20 @@ def display_summary(r_df: pl.DataFrame, py_df: pl.DataFrame):
383425

384426
# Record counts
385427
record_icon = "[green]✓[/green]" if record_match else "[red]✗[/red]"
386-
record_detail = f"Both have {r_count:,} records" if record_match else f"R: {r_count:,}, Python: {py_count:,}"
428+
record_detail = (
429+
f"Both have {r_count:,} records"
430+
if record_match
431+
else f"R: {r_count:,}, Python: {py_count:,}"
432+
)
387433
summary_table.add_row("Record counts", record_icon, record_detail)
388434

389435
# Schema
390436
schema_icon = "[green]✓[/green]" if schema_match else "[yellow]⚠[/yellow]"
391-
schema_detail = f"Both have {len(r_cols)} columns" if schema_match else f"R: {len(r_cols)}, Python: {len(py_cols)}"
437+
schema_detail = (
438+
f"Both have {len(r_cols)} columns"
439+
if schema_match
440+
else f"R: {len(r_cols)}, Python: {len(py_cols)}"
441+
)
392442
summary_table.add_row("Schema match", schema_icon, schema_detail)
393443

394444
console.print(summary_table)
@@ -414,7 +464,12 @@ def display_summary(r_df: pl.DataFrame, py_df: pl.DataFrame):
414464

415465
@app.command()
416466
def compare(
417-
file_name: str = typer.Option(..., "--file", "-f", help="Parquet filename (e.g., '2018_CDA A4D Tracker_patient_cleaned.parquet')"),
467+
file_name: str = typer.Option(
468+
...,
469+
"--file",
470+
"-f",
471+
help="Parquet filename (e.g., '2018_CDA A4D Tracker_patient_cleaned.parquet')",
472+
),
418473
):
419474
"""Compare R vs Python cleaned patient data outputs.
420475

a4d-python/scripts/reprocess_tracker.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
from pathlib import Path
55
from a4d.pipeline.tracker import process_tracker_patient
66

7-
tracker_file = Path("/Volumes/USB SanDisk 3.2Gen1 Media/a4d/a4dphase2_upload/Cambodia/CDA/2025_06_CDA A4D Tracker.xlsx")
7+
tracker_file = Path(
8+
"/Volumes/USB SanDisk 3.2Gen1 Media/a4d/a4dphase2_upload/Cambodia/CDA/2025_06_CDA A4D Tracker.xlsx"
9+
)
810
output_root = Path("/Volumes/USB SanDisk 3.2Gen1 Media/a4d/output_python")
911

1012
result = process_tracker_patient(tracker_file, output_root)

a4d-python/scripts/test_cleaning.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ def test_cleaning():
1212
"""Test cleaning on real tracker data."""
1313

1414
# Read the raw parquet we generated in Phase 2
15-
raw_path = Path("output/patient_data_raw/Python/2024_Sibu Hospital A4D Tracker_patient_raw.parquet")
15+
raw_path = Path(
16+
"output/patient_data_raw/Python/2024_Sibu Hospital A4D Tracker_patient_raw.parquet"
17+
)
1618

1719
if not raw_path.exists():
1820
print(f"❌ Raw parquet not found: {raw_path}")

0 commit comments

Comments
 (0)