
Commit a10ae3a

Merge branch 'dev' of https://github.com/awslabs/stickler into feature/vincilb/json-export

2 parents 8e79d30 + f4f46e7

47 files changed

Lines changed: 1352 additions & 853 deletions


docs/docs/Guides/StructuredModel_compare_with_README.md

Lines changed: 2 additions & 1 deletion
@@ -366,7 +366,8 @@ result = model1.compare_with(
     add_confidence_metrics=True,  # Add AUROC confidence calibration metric
     evaluator_format=False,  # Format for evaluation tools
     recall_with_fd=False,  # Include FD in recall calculation
-    add_derived_metrics=True  # Add precision/recall/F1 metrics
+    add_derived_metrics=True,  # Add precision/recall/F1 metrics
+    document_field_comparisons=False  # Document all field-level comparisons
 )
 ```
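
For context, a minimal usage sketch of the updated call. The Order schema and its plain pydantic-style fields are hypothetical illustrations; only the compare_with() keyword arguments and the result["overall"] shape come from this commit:

from stickler import StructuredModel

# Hypothetical schema for illustration; real schemas may also use
# stickler-specific field options such as ComparableField.
class Order(StructuredModel):
    order_id: str
    total: float

gt = Order(order_id="A-100", total=99.5)
pred = Order(order_id="A-100", total=95.0)

result = gt.compare_with(
    pred,
    add_derived_metrics=True,          # add precision/recall/F1 metrics
    document_field_comparisons=False,  # new flag introduced in this commit
)
print(result["overall"]["precision"])  # overall scores dict, as in the demo scripts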

docs/docs/SDK-Docs/evaluator.md

Lines changed: 4 additions & 0 deletions
@@ -8,3 +8,7 @@
 ::: stickler.structured_object_evaluator.bulk_structured_model_evaluator.BulkStructuredModelEvaluator
     options:
       heading_level: 2
+
+::: stickler.structured_object_evaluator.bulk_structured_model_evaluator.aggregate_from_comparisons
+    options:
+      heading_level: 2

examples/scripts/non_match_analysis_demo.py

Lines changed: 9 additions & 3 deletions
@@ -110,7 +110,9 @@ def demonstrate_basic_evaluation(gt_order, pred_order):
     print("\n🔍 Basic Evaluation (No Non-Match Documentation)")
     print("=" * 60)
 
-    result = gt_order.compare_with(pred_order, evaluator_format=True, document_non_matches=False)
+    result = gt_order.compare_with(
+        pred_order, evaluator_format=True, document_non_matches=False
+    )
 
     print("Overall Scores:")
     print(f" Precision: {result['overall']['precision']:.3f}")
@@ -133,7 +135,9 @@ def demonstrate_enhanced_non_matches(gt_order, pred_order):
     print("\n🔍 Enhanced Non-Match Analysis")
     print("=" * 50)
 
-    result = gt_order.compare_with(pred_order, evaluator_format=True, document_non_matches=True)
+    result = gt_order.compare_with(
+        pred_order, evaluator_format=True, document_non_matches=True
+    )
 
     # Show non-matches
     non_matches = result.get("non_matches", [])
@@ -299,7 +303,9 @@ def main():
     demonstrate_compare_with_method(gt_order, pred_order)
 
     # Analyze non-matches for practical debugging
-    result = gt_order.compare_with(pred_order, evaluator_format=True, document_non_matches=True)
+    result = gt_order.compare_with(
+        pred_order, evaluator_format=True, document_non_matches=True
+    )
     non_matches = result.get("non_matches", [])
     analyze_non_matches_for_debugging(non_matches)

examples/scripts/quick_start.py

Lines changed: 3 additions & 1 deletion
@@ -198,7 +198,9 @@ def demo_evaluator_detailed_analysis():
 
     print("Evaluating similar but not identical orders...")
 
-    result = gt_order.compare_with(pred_order, include_confusion_matrix=True, evaluator_format=True)
+    result = gt_order.compare_with(
+        pred_order, include_confusion_matrix=True, evaluator_format=True
+    )
 
     print("\n📊 Overall Metrics:")
     print(f" Precision: {result['overall']['precision']:.3f}")
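
A short follow-on sketch of what include_confusion_matrix=True adds to the result. The top-level "confusion_matrix" key is required by update_from_comparison_result() in the evaluator diff below; the "overall" entry and its counter names ("fn", and by assumption "tp"/"fp") are inferred from that diff's _confusion_matrix["overall"]["fn"] accesses:

# Continues from the wrapped compare_with() call above.
cm = result["confusion_matrix"]  # present when include_confusion_matrix=True
print(cm.get("overall"))         # e.g. counters such as "fn"; exact keys may vary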

src/stickler/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -10,6 +10,7 @@
     NonMatchField,
     NonMatchType,
     StructuredModel,
+    aggregate_from_comparisons,
     anls_score,
     compare_json,
     compare_structured_models,
@@ -25,4 +26,5 @@
     "compare_structured_models",
     "anls_score",
     "compare_json",
+    "aggregate_from_comparisons",
 ]

src/stickler/structured_object_evaluator/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@
 comparison metrics and displaying the results in a user-friendly format.
 """
 
+from .bulk_structured_model_evaluator import aggregate_from_comparisons
 from .models.comparable_field import ComparableField
 from .models.non_match_field import NonMatchField, NonMatchType
 from .models.structured_model import StructuredModel
@@ -19,6 +20,7 @@
     "compare_structured_models",
     "anls_score",
     "compare_json",
+    "aggregate_from_comparisons",
     "ScoreNode",
     "construct_nested_dict",
     "merge_and_calculate_mean",

src/stickler/structured_object_evaluator/bulk_structured_model_evaluator.py

Lines changed: 93 additions & 30 deletions
@@ -41,7 +41,7 @@ class BulkStructuredModelEvaluator:
 
     def __init__(
         self,
-        target_schema: Type[StructuredModel],
+        target_schema: Optional[Type[StructuredModel]] = None,
         verbose: bool = False,
         document_non_matches: bool = True,
         elide_errors: bool = False,
@@ -51,7 +51,9 @@ def __init__(
         Initialize the stateful bulk evaluator.
 
         Args:
-            target_schema: StructuredModel class for validation and processing
+            target_schema: Optional StructuredModel class for validation and processing.
+                Required for update() and evaluate_dataframe(). Not required when using
+                update_from_comparison_result() with pre-computed results.
             verbose: Whether to print detailed progress information
             document_non_matches: Whether to document detailed non-match information
             elide_errors: If True, skip documents with errors; if False, accumulate error metrics
@@ -66,10 +68,10 @@ def __init__(
         # Initialize state
         self.reset()
 
+        self._schema_name = target_schema.__name__ if target_schema else "unknown"
+
         if self.verbose:
-            print(
-                f"Initialized BulkStructuredModelEvaluator for {target_schema.__name__}"
-            )
+            print(f"Initialized BulkStructuredModelEvaluator for {self._schema_name}")
         if self.individual_results_jsonl:
             print(
                 f"Individual results will be appended to: {self.individual_results_jsonl}"
@@ -111,9 +113,8 @@ def update(
         """
         Process a single document pair and accumulate the results in internal state.
 
-        This is the core method for stateful evaluation, inspired by PyTorch Lightning's
-        training_step pattern. Each call processes one document pair and updates
-        the internal confusion matrix counters.
+        Runs compare_with() on the model pair, optionally writes the raw result
+        to JSONL, then delegates accumulation to update_from_comparison_result().
 
         Args:
             gt_model: Ground truth StructuredModel instance
@@ -124,29 +125,70 @@ def update(
         doc_id = f"doc_{self._processed_count}"
 
         try:
-            # Use compare_with method directly on the StructuredModel
-            # Pass document_non_matches to achieve parity with compare_with method
             comparison_result = gt_model.compare_with(
                 pred_model,
                 include_confusion_matrix=True,
                 document_non_matches=self.document_non_matches,
             )
 
-            # Collect non-matches if enabled
+            # JSONL append of raw comparison result before accumulation
+            if self.individual_results_jsonl:
+                record = {"doc_id": doc_id, "comparison_result": comparison_result}
+                with open(self.individual_results_jsonl, "a", encoding="utf-8") as f:
+                    f.write(json.dumps(record) + "\n")
+
+            self.update_from_comparison_result(comparison_result, doc_id)
+
+        except Exception as e:
+            error_record = {
+                "doc_id": doc_id,
+                "error": str(e),
+                "error_type": type(e).__name__,
+            }
+
+            if not self.elide_errors:
+                self._errors.append(error_record)
+                self._confusion_matrix["overall"]["fn"] += 1
+
+            if self.verbose:
+                print(f"Error processing document {doc_id}: {str(e)}")
+
+    def update_from_comparison_result(
+        self,
+        comparison_result: Dict[str, Any],
+        doc_id: Optional[str] = None,
+    ) -> None:
+        """
+        Accumulate a pre-computed compare_with() result into internal state.
+
+        Unlike update(), this method does not require StructuredModel instances
+        or re-run comparisons. It accepts the raw dictionary output of
+        StructuredModel.compare_with(include_confusion_matrix=True) and
+        accumulates its confusion matrix.
+
+        Args:
+            comparison_result: Dictionary returned by StructuredModel.compare_with()
+                with include_confusion_matrix=True. Must contain a "confusion_matrix" key.
+            doc_id: Optional document identifier for error tracking
+        """
+        if doc_id is None:
+            doc_id = f"doc_{self._processed_count}"
+
+        try:
+            if "confusion_matrix" not in comparison_result:
+                raise ValueError(
+                    "comparison_result must contain a 'confusion_matrix' key. "
+                    "Ensure compare_with() was called with include_confusion_matrix=True."
+                )
+
+            # Collect non-matches if enabled and present
             if self.document_non_matches and "non_matches" in comparison_result:
-                # Add doc_id to each non-match for bulk tracking
                 for non_match in comparison_result["non_matches"]:
                     non_match_with_doc = non_match.copy()
                     non_match_with_doc["doc_id"] = doc_id
                     self._non_matches.append(non_match_with_doc)
 
-            # Simple JSONL append of raw comparison result (before any processing)
-            if self.individual_results_jsonl:
-                record = {"doc_id": doc_id, "comparison_result": comparison_result}
-                with open(self.individual_results_jsonl, "a", encoding="utf-8") as f:
-                    f.write(json.dumps(record) + "\n")
-
-            # Accumulate the results into our state (this flattens for aggregation)
+            # Accumulate the confusion matrix
            self._accumulate_confusion_matrix(comparison_result["confusion_matrix"])
 
             self._processed_count += 1
@@ -164,9 +206,6 @@ def update(
 
             if not self.elide_errors:
                 self._errors.append(error_record)
-
-                # For errors, add a "failed" classification to overall metrics
-                # This represents complete failure to process the document
                 self._confusion_matrix["overall"]["fn"] += 1
 
             if self.verbose:
@@ -454,7 +493,7 @@ def save_metrics(self, filepath: str) -> None:
                 "error_rate": len(process_eval.errors) / self._processed_count
                 if self._processed_count > 0
                 else 0,
-                "target_schema": self.target_schema.__name__,
+                "target_schema": self._schema_name,
             },
             "errors": process_eval.errors,
             "metadata": {
@@ -491,7 +530,7 @@ def pretty_print_metrics(self) -> None:
 
         # Header
         print("\n" + "=" * 80)
-        print(f"BULK EVALUATION RESULTS - {self.target_schema.__name__}")
+        print(f"BULK EVALUATION RESULTS - {self._schema_name}")
         print("=" * 80)
 
         # Overall metrics
@@ -575,7 +614,7 @@ def pretty_print_metrics(self) -> None:
         # Configuration info
         print("\nCONFIGURATION:")
         print("-" * 40)
-        print(f"Target Schema: {self.target_schema.__name__}")
+        print(f"Target Schema: {self._schema_name}")
         print(f"Document Non-matches: {'Yes' if self.document_non_matches else 'No'}")
         print(f"Elide Errors: {'Yes' if self.elide_errors else 'No'}")
         if self.individual_results_jsonl:
@@ -606,7 +645,7 @@ def get_state(self) -> Dict[str, Any]:
             "processed_count": self._processed_count,
             "start_time": self._start_time,
             # Configuration
-            "target_schema": self.target_schema.__name__,
+            "target_schema": self._schema_name,
             "elide_errors": self.elide_errors,
         }
 
@@ -621,9 +660,9 @@ def load_state(self, state: Dict[str, Any]) -> None:
            state: State dictionary from get_state()
         """
         # Validate state compatibility
-        if state.get("target_schema") != self.target_schema.__name__:
+        if state.get("target_schema") != self._schema_name:
             raise ValueError(
-                f"State schema {state.get('target_schema')} doesn't match evaluator schema {self.target_schema.__name__}"
+                f"State schema {state.get('target_schema')} doesn't match evaluator schema {self._schema_name}"
             )
 
         # Restore confusion matrix state
@@ -658,9 +697,9 @@ def merge_state(self, other_state: Dict[str, Any]) -> None:
            other_state: State dictionary from another evaluator instance
         """
         # Validate compatibility
-        if other_state.get("target_schema") != self.target_schema.__name__:
+        if other_state.get("target_schema") != self._schema_name:
             raise ValueError(
-                f"Cannot merge incompatible schemas: {other_state.get('target_schema')} vs {self.target_schema.__name__}"
+                f"Cannot merge incompatible schemas: {other_state.get('target_schema')} vs {self._schema_name}"
             )
 
         # Merge overall metrics
@@ -722,3 +761,27 @@ def evaluate_dataframe(self, df) -> ProcessEvaluation:
                 continue
 
         return self.compute()
+
+
+def aggregate_from_comparisons(
+    comparison_results: List[Dict[str, Any]],
+) -> ProcessEvaluation:
+    """
+    Aggregate a list of pre-computed compare_with() results into field-level metrics.
+
+    This is a convenience function for aggregating stored comparison results
+    without needing the original StructuredModel instances. It accepts the raw
+    dictionary outputs of StructuredModel.compare_with(include_confusion_matrix=True).
+
+    Args:
+        comparison_results: List of dictionaries, each returned by
+            StructuredModel.compare_with(include_confusion_matrix=True).
+
+    Returns:
+        ProcessEvaluation with aggregated metrics including overall and
+        per-field precision, recall, F1, and accuracy.
+    """
+    evaluator = BulkStructuredModelEvaluator()
+    for result in comparison_results:
+        evaluator.update_from_comparison_result(result)
+    return evaluator.compute()
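
Taken together, the JSONL records that update() writes and the new aggregate_from_comparisons() function allow metrics to be recomputed offline, without the original StructuredModel instances. A minimal sketch, assuming a results.jsonl previously written by an evaluator constructed with individual_results_jsonl set (the record shape mirrors the update() code above; the file name is illustrative):

import json

from stickler import aggregate_from_comparisons

# Each line is {"doc_id": ..., "comparison_result": ...}, as written by update().
with open("results.jsonl", encoding="utf-8") as f:
    records = [json.loads(line) for line in f]

# The stored results must have been produced with include_confusion_matrix=True;
# otherwise update_from_comparison_result() raises ValueError.
evaluation = aggregate_from_comparisons([r["comparison_result"] for r in records])
# evaluation is a ProcessEvaluation with overall and per-field
# precision/recall/F1/accuracy, per the docstring above.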

src/stickler/structured_object_evaluator/models/comparison_dispatcher.py

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@
 
 from .null_helper import NullHelper
 from .result_helper import ResultHelper
+from .null_helper import NullHelper
 
 if TYPE_CHECKING:
     from .structured_model import StructuredModel
