@@ -41,7 +41,7 @@ class BulkStructuredModelEvaluator:
 
     def __init__(
         self,
-        target_schema: Type[StructuredModel],
+        target_schema: Optional[Type[StructuredModel]] = None,
         verbose: bool = False,
         document_non_matches: bool = True,
         elide_errors: bool = False,
@@ -51,7 +51,9 @@ def __init__(
         Initialize the stateful bulk evaluator.
 
         Args:
-            target_schema: StructuredModel class for validation and processing
+            target_schema: Optional StructuredModel class for validation and processing.
+                Required for update() and evaluate_dataframe(). Not required when using
+                update_from_comparison_result() with pre-computed results.
             verbose: Whether to print detailed progress information
             document_non_matches: Whether to document detailed non-match information
             elide_errors: If True, skip documents with errors; if False, accumulate error metrics
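
For reference, a minimal construction sketch under the new optional-schema signature; the `Invoice` model is a hypothetical stand-in for any StructuredModel subclass:

    # Schema-backed evaluator, needed for update()/evaluate_dataframe()
    evaluator = BulkStructuredModelEvaluator(target_schema=Invoice, verbose=True)

    # Schema-less evaluator for replaying pre-computed comparison results;
    # the reported schema name falls back to "unknown"
    replay = BulkStructuredModelEvaluator()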
@@ -66,10 +68,10 @@ def __init__(
         # Initialize state
         self.reset()
 
+        self._schema_name = target_schema.__name__ if target_schema else "unknown"
+
         if self.verbose:
-            print(
-                f"Initialized BulkStructuredModelEvaluator for {target_schema.__name__}"
-            )
+            print(f"Initialized BulkStructuredModelEvaluator for {self._schema_name}")
             if self.individual_results_jsonl:
                 print(
                     f"Individual results will be appended to: {self.individual_results_jsonl}"
@@ -111,9 +113,8 @@ def update(
         """
         Process a single document pair and accumulate the results in internal state.
 
-        This is the core method for stateful evaluation, inspired by PyTorch Lightning's
-        training_step pattern. Each call processes one document pair and updates
-        the internal confusion matrix counters.
+        Runs compare_with() on the model pair, optionally writes the raw result
+        to JSONL, then delegates accumulation to update_from_comparison_result().
 
         Args:
             gt_model: Ground truth StructuredModel instance
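
The stateful loop this docstring describes might look like the following sketch, where `gt_models`, `pred_models`, and `Invoice` are hypothetical, already-parsed inputs:

    evaluator = BulkStructuredModelEvaluator(target_schema=Invoice)
    for gt, pred in zip(gt_models, pred_models):
        evaluator.update(gt, pred)   # compare, optionally log JSONL, accumulate
    metrics = evaluator.compute()    # aggregate into a ProcessEvaluation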
@@ -124,29 +125,70 @@ def update(
             doc_id = f"doc_{self._processed_count}"
 
         try:
-            # Use compare_with method directly on the StructuredModel
-            # Pass document_non_matches to achieve parity with compare_with method
             comparison_result = gt_model.compare_with(
                 pred_model,
                 include_confusion_matrix=True,
                 document_non_matches=self.document_non_matches,
             )
 
-            # Collect non-matches if enabled
+            # JSONL append of raw comparison result before accumulation
+            if self.individual_results_jsonl:
+                record = {"doc_id": doc_id, "comparison_result": comparison_result}
+                with open(self.individual_results_jsonl, "a", encoding="utf-8") as f:
+                    f.write(json.dumps(record) + "\n")
+
+            self.update_from_comparison_result(comparison_result, doc_id)
+
+        except Exception as e:
+            error_record = {
+                "doc_id": doc_id,
+                "error": str(e),
+                "error_type": type(e).__name__,
+            }
+
+            if not self.elide_errors:
+                self._errors.append(error_record)
+                self._confusion_matrix["overall"]["fn"] += 1
+
+            if self.verbose:
+                print(f"Error processing document {doc_id}: {str(e)}")
+
+    def update_from_comparison_result(
+        self,
+        comparison_result: Dict[str, Any],
+        doc_id: Optional[str] = None,
+    ) -> None:
+        """
+        Accumulate a pre-computed compare_with() result into internal state.
+
+        Unlike update(), this method does not require StructuredModel instances
+        or re-run comparisons. It accepts the raw dictionary output of
+        StructuredModel.compare_with(include_confusion_matrix=True) and
+        accumulates its confusion matrix.
+
+        Args:
+            comparison_result: Dictionary returned by StructuredModel.compare_with()
+                with include_confusion_matrix=True. Must contain a "confusion_matrix" key.
+            doc_id: Optional document identifier for error tracking
+        """
+        if doc_id is None:
+            doc_id = f"doc_{self._processed_count}"
+
+        try:
+            if "confusion_matrix" not in comparison_result:
+                raise ValueError(
+                    "comparison_result must contain a 'confusion_matrix' key. "
+                    "Ensure compare_with() was called with include_confusion_matrix=True."
+                )
+
+            # Collect non-matches if enabled and present
             if self.document_non_matches and "non_matches" in comparison_result:
-                # Add doc_id to each non-match for bulk tracking
                 for non_match in comparison_result["non_matches"]:
                     non_match_with_doc = non_match.copy()
                     non_match_with_doc["doc_id"] = doc_id
                     self._non_matches.append(non_match_with_doc)
 
-            # Simple JSONL append of raw comparison result (before any processing)
-            if self.individual_results_jsonl:
-                record = {"doc_id": doc_id, "comparison_result": comparison_result}
-                with open(self.individual_results_jsonl, "a", encoding="utf-8") as f:
-                    f.write(json.dumps(record) + "\n")
-
-            # Accumulate the results into our state (this flattens for aggregation)
+            # Accumulate the confusion matrix
            self._accumulate_confusion_matrix(comparison_result["confusion_matrix"])
 
             self._processed_count += 1
@@ -164,9 +206,6 @@ def update(
 
             if not self.elide_errors:
                 self._errors.append(error_record)
-
-                # For errors, add a "failed" classification to overall metrics
-                # This represents complete failure to process the document
                 self._confusion_matrix["overall"]["fn"] += 1
 
             if self.verbose:
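
Taken together, update() now splits comparison from accumulation, so stored results can be replayed without the original models. A hedged sketch of that replay path, assuming a JSONL file written by update() (the filename is hypothetical; the record shape matches the `{"doc_id": ..., "comparison_result": ...}` records written above):

    import json

    evaluator = BulkStructuredModelEvaluator()
    with open("comparison_results.jsonl", encoding="utf-8") as f:
        for line in f:
            record = json.loads(line)
            evaluator.update_from_comparison_result(
                record["comparison_result"], doc_id=record["doc_id"]
            )
    metrics = evaluator.compute()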
@@ -454,7 +493,7 @@ def save_metrics(self, filepath: str) -> None:
                 "error_rate": len(process_eval.errors) / self._processed_count
                 if self._processed_count > 0
                 else 0,
-                "target_schema": self.target_schema.__name__,
+                "target_schema": self._schema_name,
             },
             "errors": process_eval.errors,
             "metadata": {
@@ -491,7 +530,7 @@ def pretty_print_metrics(self) -> None:
 
         # Header
         print("\n" + "=" * 80)
-        print(f"BULK EVALUATION RESULTS - {self.target_schema.__name__}")
+        print(f"BULK EVALUATION RESULTS - {self._schema_name}")
         print("=" * 80)
 
         # Overall metrics
@@ -575,7 +614,7 @@ def pretty_print_metrics(self) -> None:
         # Configuration info
         print("\nCONFIGURATION:")
         print("-" * 40)
-        print(f"Target Schema: {self.target_schema.__name__}")
+        print(f"Target Schema: {self._schema_name}")
         print(f"Document Non-matches: {'Yes' if self.document_non_matches else 'No'}")
         print(f"Elide Errors: {'Yes' if self.elide_errors else 'No'}")
         if self.individual_results_jsonl:
@@ -606,7 +645,7 @@ def get_state(self) -> Dict[str, Any]:
             "processed_count": self._processed_count,
             "start_time": self._start_time,
             # Configuration
-            "target_schema": self.target_schema.__name__,
+            "target_schema": self._schema_name,
             "elide_errors": self.elide_errors,
         }
 
@@ -621,9 +660,9 @@ def load_state(self, state: Dict[str, Any]) -> None:
             state: State dictionary from get_state()
         """
         # Validate state compatibility
-        if state.get("target_schema") != self.target_schema.__name__:
+        if state.get("target_schema") != self._schema_name:
             raise ValueError(
-                f"State schema {state.get('target_schema')} doesn't match evaluator schema {self.target_schema.__name__}"
+                f"State schema {state.get('target_schema')} doesn't match evaluator schema {self._schema_name}"
             )
 
         # Restore confusion matrix state
@@ -658,9 +697,9 @@ def merge_state(self, other_state: Dict[str, Any]) -> None:
             other_state: State dictionary from another evaluator instance
         """
         # Validate compatibility
-        if other_state.get("target_schema") != self.target_schema.__name__:
+        if other_state.get("target_schema") != self._schema_name:
             raise ValueError(
-                f"Cannot merge incompatible schemas: {other_state.get('target_schema')} vs {self.target_schema.__name__}"
+                f"Cannot merge incompatible schemas: {other_state.get('target_schema')} vs {self._schema_name}"
             )
 
         # Merge overall metrics
@@ -722,3 +761,27 @@ def evaluate_dataframe(self, df) -> ProcessEvaluation:
                 continue
 
         return self.compute()
+
+
+def aggregate_from_comparisons(
+    comparison_results: List[Dict[str, Any]],
+) -> ProcessEvaluation:
+    """
+    Aggregate a list of pre-computed compare_with() results into field-level metrics.
+
+    This is a convenience function for aggregating stored comparison results
+    without needing the original StructuredModel instances. It accepts the raw
+    dictionary outputs of StructuredModel.compare_with(include_confusion_matrix=True).
+
+    Args:
+        comparison_results: List of dictionaries, each returned by
+            StructuredModel.compare_with(include_confusion_matrix=True).
+
+    Returns:
+        ProcessEvaluation with aggregated metrics including overall and
+        per-field precision, recall, F1, and accuracy.
+    """
+    evaluator = BulkStructuredModelEvaluator()
+    for result in comparison_results:
+        evaluator.update_from_comparison_result(result)
+    return evaluator.compute()
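
And the one-shot aggregation path through the new module-level helper, assuming `pairs` is a hypothetical iterable of (ground truth, prediction) StructuredModel pairs:

    stored = [
        gt.compare_with(pred, include_confusion_matrix=True)
        for gt, pred in pairs
    ]
    evaluation = aggregate_from_comparisons(stored)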