Skip to content

Commit 75c31c7

Browse files
authored
Fixed LLM comparator and added pytests for it (awslabs#15)
Update LLM comparator to use Strands, add tests and example script
1 parent ce6f4a9 commit 75c31c7

6 files changed

Lines changed: 763 additions & 169 deletions

File tree

.github/workflows/run_pytest.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ jobs:
2222
run: |
2323
python -m pip install --upgrade pip
2424
pip install -e ".[dev]"
25+
pip install -e ".[llm]"
2526
- name: Test with pytest
2627
run: |
2728
coverage run -m pytest -v -s
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
#!/usr/bin/env python3
2+
"""
3+
LLM Comparator Demo Script
4+
5+
This script demonstrates the LLMComparator functionality for semantic comparison
6+
of values using Large Language Models. The LLMComparator leverages AWS Bedrock
7+
models through the strands-agents library to perform intelligent comparisons
8+
that go beyond simple string matching.
9+
10+
Requirements:
11+
- AWS credentials configured for Bedrock access
12+
- Environment variables for model configuration (optional)
13+
"""
14+
from stickler.comparators.llm import LLMComparator
15+
from stickler.comparators.exact import ExactComparator
16+
from stickler.comparators.levenshtein import LevenshteinComparator
17+
from stickler.structured_object_evaluator.models.structured_model import StructuredModel
18+
from stickler.structured_object_evaluator.models.comparable_field import ComparableField
19+
20+
21+
def print_section_header(title: str) -> None:
    """Print a section banner for ``title`` to standard output.

    The banner is a blank line, a rule of 60 ``=`` characters, the title
    prefixed with a magnifying-glass emoji, and a closing rule of the
    same width.

    Args:
        title: Text shown between the two rules.
    """
    rule = "=" * 60
    # The leading newline separates this banner from earlier output.
    print(f"\n{rule}")
    print(f"🔍 {title}")
    print(rule)
26+
27+
28+
def demo_structured_model_integration():
    """Demonstrate LLM comparator integration with StructuredModel.

    Defines two nested models (``CustomerAddress`` inside ``Customer``)
    whose fields use different comparators, builds a ground-truth record
    and a predicted record that differ only in the street abbreviation
    ("Street" vs "St"), compares them, and prints the ``'overall'`` entry
    of each field found in the returned confusion matrix.

    The ``street`` field is compared with ``LLMComparator``; per the
    module docstring this goes through AWS Bedrock, so running this
    function presumably requires working credentials and model access.
    """
    print_section_header("STRUCTURED MODEL INTEGRATION")

    # Define a customer model with mixed comparators
    class CustomerAddress(StructuredModel):
        street: str = ComparableField(
            comparator=LLMComparator(
                model="us.amazon.nova-lite-v1:0",
                eval_guidelines="Consider street abbreviations equivalent (St=Street, Ave=Avenue, etc.)",
            ),
            threshold=0.8,
            weight=1.0,
        )
        city: str = ComparableField(
            comparator=LevenshteinComparator(),
            threshold=0.9,
            weight=1.0,
        )
        zip_code: str = ComparableField(
            comparator=ExactComparator(),
            threshold=1.0,
            weight=1.0,
        )

    class Customer(StructuredModel):
        name: str = ComparableField(
            comparator=ExactComparator(),
            threshold=0.8,
            weight=1.0,
        )
        email: str = ComparableField(
            comparator=ExactComparator(),
            threshold=1.0,
            weight=1.0,
        )
        # NOTE(review): a nested model field configured with ExactComparator;
        # confirm this is the intended way to declare nested StructuredModels.
        address: CustomerAddress = ComparableField(
            comparator=ExactComparator(),
            threshold=1.0,
            weight=1.0,
        )

    print("Comparing customer records with mixed comparator types...")

    # Ground truth customer
    gt_customer = Customer(
        name="Robert Johnson",
        email="robert.johnson@email.com",
        address=CustomerAddress(
            street="123 Main Street",
            city="Seattle",
            zip_code="98101",
        ),
    )

    # Predicted customer with variations
    pred_customer = Customer(
        name="Robert Johnson",
        email="robert.johnson@email.com",
        address=CustomerAddress(
            street="123 Main St",  # Street abbreviation
            city="Seattle",
            zip_code="98101",
        ),
    )

    # Compare the customers
    result = gt_customer.compare_with(pred_customer, include_confusion_matrix=True)

    # Show field-level results
    print("\nField-level comparison results:")
    cm = result['confusion_matrix']
    for field_name, field_data in cm['fields'].items():
        field_result = field_data['overall']
        print(f"  {field_name}: {field_result}")
103+
104+
105+
def main() -> int:
    """Run all demonstration functions.

    Prints an introductory banner, runs the structured-model demo, and
    prints a closing summary.

    Returns:
        0 when the demo finishes, or 1 when any exception is raised
        while it runs (the error is printed, not re-raised).
    """
    print("🚀 LLM COMPARATOR COMPREHENSIVE DEMO")
    print("=" * 60)
    print("This demo showcases the LLMComparator functionality for")
    print("semantic comparison using Large Language Models.")

    # Check for required environment setup
    # NOTE(review): only the heading is printed; no actual check follows.
    print("\n📋 Environment Check:")

    # Top-level boundary of the script: any failure is reported to the
    # user and mapped to a non-zero exit status instead of a traceback.
    try:
        demo_structured_model_integration()

        print_section_header("DEMO COMPLETE")
        print("✅ All demonstrations completed successfully!")
        print("\n💡 Key Takeaways:")
        print(" • LLMComparator provides semantic comparison beyond string matching")
        print(" • Integrates seamlessly with StructuredModel for complex objects")

        print("\n🔧 Best Practices:")
        print(" • Use specific guidelines for better accuracy")
        print(" • Choose appropriate models for your use case")
        print(" • Handle None values and edge cases")
        print(" • Monitor API costs and latency")
        print(" • Test with representative data")

    except Exception as e:
        print(f"\n❌ Demo failed with error: {e}")
        print("Please check your AWS credentials and model access.")
        return 1

    return 0
138+
139+
140+
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    # SystemExit is raised directly because the bare ``exit`` helper is
    # installed by the ``site`` module for interactive use and is not
    # available when the interpreter is started with ``-S``.
    raise SystemExit(main())

pyproject.toml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ readme = "README.md"
1212
requires-python = ">=3.12"
1313

1414
dependencies = [
15-
"pydantic>=2.11.0",
15+
"pydantic>=2.11.0,<3.0.0",
1616
"rapidfuzz>=3.0.0",
1717
"munkres>=1.1.4",
18-
"numpy>=1.24.0",
18+
"numpy>=1.24.0,<=2.3.3",
1919
"scipy>=1.10.0",
2020
"psutil>=5.8.0",
2121
"pandas>=1.5.0",
@@ -30,6 +30,11 @@ dev = [
3030
"beautifulsoup4>=4.14.2"
3131
]
3232

33+
llm = [
34+
"strands-agents>=1.0.0,<=1.16.0",
35+
"jinja2>=3.0.0,<=3.1.6"
36+
]
37+
3338

3439
[tool.setuptools]
3540
package-dir = {"" = "src"}

0 commit comments

Comments
 (0)