@@ -141,8 +141,8 @@ def compare_metadata_fields(r_df: pl.DataFrame, py_df: pl.DataFrame):
141141
142142 # Key metadata fields that must be identical
143143 metadata_fields = [
144- "tracker_year" , "tracker_month" , "file_name " ,
145- "national_id " , "start_date " , "end_date "
144+ "tracker_year" , "tracker_month" , "tracker_date " ,
145+ "file_name " , "sheet_name " , "patient_id "
146146 ]
147147
148148 existing_fields = [f for f in metadata_fields if f in r_df .columns and f in py_df .columns ]
@@ -185,17 +185,17 @@ def compare_patient_records(r_df: pl.DataFrame, py_df: pl.DataFrame, n_samples:
185185 """Compare sample patient records in detail."""
186186 console .print (Panel (f"[bold]Sample Patient Records (first { n_samples } )[/bold]" , expand = False ))
187187
188- if "national_id " not in r_df .columns or "national_id " not in py_df .columns :
189- console .print ("[yellow]Cannot compare records: national_id column missing[/yellow]\n " )
188+ if "patient_id " not in r_df .columns or "patient_id " not in py_df .columns :
189+ console .print ("[yellow]Cannot compare records: patient_id column missing[/yellow]\n " )
190190 return
191191
192- # Get first n national_ids from R
193- sample_ids = r_df ["national_id " ].head (n_samples ).to_list ()
192+ # Get first n patient_ids from R
193+ sample_ids = r_df ["patient_id " ].head (n_samples ).to_list ()
194194
195- for idx , national_id in enumerate (sample_ids , 1 ):
196- console .print (f"\n [bold]Patient { idx } :[/bold] { national_id } " )
195+ for idx , patient_id in enumerate (sample_ids , 1 ):
196+ console .print (f"\n [bold]Patient { idx } :[/bold] { patient_id } " )
197197
198- py_records = py_df .filter (pl .col ("national_id " ) == national_id )
198+ py_records = py_df .filter (pl .col ("patient_id " ) == patient_id )
199199
200200 if len (py_records ) == 0 :
201201 console .print ("[red] ✗ Not found in Python output![/red]" )
@@ -204,12 +204,12 @@ def compare_patient_records(r_df: pl.DataFrame, py_df: pl.DataFrame, n_samples:
204204 console .print (f"[yellow] ⚠ Multiple records in Python ({ len (py_records )} )[/yellow]" )
205205
206206 # Compare key fields
207- r_record = r_df .filter (pl .col ("national_id " ) == national_id ).head (1 ).to_dicts ()[0 ]
207+ r_record = r_df .filter (pl .col ("patient_id " ) == patient_id ).head (1 ).to_dicts ()[0 ]
208208 py_record = py_records .head (1 ).to_dicts ()[0 ]
209209
210210 comparison_fields = [
211- "tracker_year" , "tracker_month" , "start_date " , "end_date " ,
212- "sex" , "age_group " , "diagnosis_malaria "
211+ "tracker_year" , "tracker_month" , "tracker_date " , "sheet_name " ,
212+ "sex" , "age " , "dob" , "status" , "province "
213213 ]
214214
215215 comp_table = Table (box = box .SIMPLE , show_header = False )
@@ -241,40 +241,85 @@ def find_value_mismatches(r_df: pl.DataFrame, py_df: pl.DataFrame):
241241 """Find all value differences for common records."""
242242 console .print (Panel ("[bold]Value Mismatches Analysis[/bold]" , expand = False ))
243243
244- if "national_id" not in r_df .columns or "national_id" not in py_df .columns :
245- console .print ("[yellow]Cannot analyze values: national_id column missing[/yellow]\n " )
244+ if "patient_id" not in r_df .columns or "patient_id" not in py_df .columns :
245+ console .print ("[yellow]Cannot analyze values: patient_id column missing[/yellow]\n " )
246+ return
247+
248+ # Join on patient_id + sheet_name to match same month records
249+ # (patients can have multiple records across different months)
250+ join_keys = ["patient_id" , "sheet_name" ]
251+ if not all (key in r_df .columns and key in py_df .columns for key in join_keys ):
252+ console .print (f"[yellow]Cannot analyze values: missing join keys { join_keys } [/yellow]\n " )
246253 return
247254
248- # Join on national_id
249255 try :
250- joined = r_df .join (py_df , on = "national_id" , how = "inner" , suffix = "_py" )
251- console .print (f"[cyan]Analyzing { len (joined ):,} common records (matched on national_id )[/cyan]\n " )
256+ joined = r_df .join (py_df , on = join_keys , how = "inner" , suffix = "_py" )
257+ console .print (f"[cyan]Analyzing { len (joined ):,} common records (matched on { '+' . join ( join_keys ) } )[/cyan]\n " )
252258 except Exception as e :
253259 console .print (f"[red]Error joining datasets: { e } [/red]\n " )
254260 return
255261
256- # Find columns in both datasets
257- common_cols = set (r_df .columns ) & set (py_df .columns ) - { "national_id" }
262+ # Find columns in both datasets (excluding join keys)
263+ common_cols = set (r_df .columns ) & set (py_df .columns ) - set ( join_keys )
258264
259265 mismatches = {}
260266
267+ # Tolerance for floating point comparisons
268+ # Use relative tolerance of 1e-9 (about 9 decimal places)
269+ FLOAT_REL_TOL = 1e-9
270+ FLOAT_ABS_TOL = 1e-12
271+
261272 for col in sorted (common_cols ):
262273 col_py = f"{ col } _py"
263274 if col in joined .columns and col_py in joined .columns :
264275 try :
265- # Count mismatches
266- mismatched_rows = joined .filter (pl .col (col ) != pl .col (col_py ))
276+ # Check if column is numeric (float or int)
277+ col_dtype = joined [col ].dtype
278+ is_numeric = col_dtype in [pl .Float32 , pl .Float64 , pl .Int8 , pl .Int16 , pl .Int32 , pl .Int64 , pl .UInt8 , pl .UInt16 , pl .UInt32 , pl .UInt64 ]
279+
280+ if is_numeric :
281+ # For numeric columns, use approximate comparison
282+ # Two values are considered equal if |a - b| <= max(rel_tol * max(|a|, |b|), abs_tol)
283+
284+ # Add columns for comparison logic
285+ comparison_df = joined .with_columns ([
286+ # Calculate absolute difference
287+ ((pl .col (col ) - pl .col (col_py )).abs ()).alias ("_abs_diff" ),
288+ # Calculate tolerance threshold
289+ pl .max_horizontal ([
290+ FLOAT_REL_TOL * pl .max_horizontal ([pl .col (col ).abs (), pl .col (col_py ).abs ()]),
291+ pl .lit (FLOAT_ABS_TOL )
292+ ]).alias ("_tolerance" ),
293+ # Check null status
294+ pl .col (col ).is_null ().alias ("_col_null" ),
295+ pl .col (col_py ).is_null ().alias ("_col_py_null" ),
296+ ])
297+
298+ # Find mismatches
299+ # Mismatch if: (1) null status differs OR (2) both not null and differ by more than tolerance
300+ mismatched_rows = comparison_df .filter (
301+ (pl .col ("_col_null" ) != pl .col ("_col_py_null" )) | # Null mismatch
302+ ((~ pl .col ("_col_null" )) & (pl .col ("_abs_diff" ) > pl .col ("_tolerance" ))) # Value mismatch
303+ )
304+ else :
305+ # For non-numeric columns, use exact comparison
306+ mismatched_rows = joined .filter (pl .col (col ) != pl .col (col_py ))
307+
267308 mismatch_count = len (mismatched_rows )
268309
269310 if mismatch_count > 0 :
270311 mismatch_pct = (mismatch_count / len (joined )) * 100
312+ # Include patient_id and sheet_name in examples for debugging
313+ examples_with_ids = mismatched_rows .select (["patient_id" , "sheet_name" , col , col_py ])
271314 mismatches [col ] = {
272315 "count" : mismatch_count ,
273316 "percentage" : mismatch_pct ,
274- "examples" : mismatched_rows .select ([col , col_py ]).head (3 )
317+ "examples" : mismatched_rows .select ([col , col_py ]).head (3 ),
318+ "examples_with_ids" : examples_with_ids
275319 }
276- except Exception :
320+ except Exception as e :
277321 # Some columns might not support comparison
322+ console .print (f"[dim]Skipped column '{ col } ': { e } [/dim]" )
278323 pass
279324
280325 if mismatches :
@@ -286,7 +331,7 @@ def find_value_mismatches(r_df: pl.DataFrame, py_df: pl.DataFrame):
286331
287332 for col , stats in sorted (mismatches .items (), key = lambda x : x [1 ]["percentage" ], reverse = True ):
288333 # Determine priority
289- if col in ["national_id " , "tracker_year" , "tracker_month" , "start_date " , "end_date " ]:
334+ if col in ["patient_id " , "tracker_year" , "tracker_month" , "tracker_date " , "file_name" , "sheet_name " ]:
290335 priority = "[red]HIGH[/red]"
291336 elif stats ["percentage" ] > 10 :
292337 priority = "[yellow]MEDIUM[/yellow]"
@@ -302,11 +347,13 @@ def find_value_mismatches(r_df: pl.DataFrame, py_df: pl.DataFrame):
302347
303348 console .print (mismatch_table )
304349
305- # Show some examples
306- console .print ("\n [dim]Examples of mismatches (first 3 columns with highest mismatch %):[/dim]" )
307- for col , stats in list (sorted (mismatches .items (), key = lambda x : x [1 ]["percentage" ], reverse = True ))[:3 ]:
308- console .print (f"\n [bold]{ col } :[/bold]" )
309- console .print (stats ["examples" ])
350+ # Show ALL mismatched columns with patient_id and sheet_name
351+ console .print ("\n [bold]Detailed Mismatches (showing ALL errors):[/bold]" )
352+ for col , stats in sorted (mismatches .items (), key = lambda x : x [1 ]["percentage" ], reverse = True ):
353+ console .print (f"\n [bold cyan]{ col } :[/bold cyan] { stats ['count' ]} mismatches ({ stats ['percentage' ]:.1f} %)" )
354+ # Include patient_id and sheet_name in examples
355+ examples_with_ids = stats ["examples_with_ids" ]
356+ console .print (examples_with_ids )
310357
311358 else :
312359 console .print ("[green]✓ All values match for common records![/green]" )
0 commit comments