99library and supports customizable evaluation guidelines for domain-specific
1010comparison logic.
1111
12+ Note:
13+ This comparator requires the optional 'llm' dependencies. Install with:
14+ pip install stickler-eval[llm]
15+
1216Example:
1317 Integration with StructuredModel:
1418 >>> from stickler.structured_object_evaluator.models.comparable_field import ComparableField
15- >>>
19+ >>>
1620 >>> class Address(StructuredModel):
1721 ... street: str = ComparableField(
1822 ... comparator=LLMComparator(eval_guidelines="Consider street abbreviations"),
1923 ... threshold=0.8
2024 ... )
2125"""
26+
2227import html
2328from typing import Any , Dict , Union
2429
25- from botocore .exceptions import NoCredentialsError
2630from jinja2 import Template
27- from strands import Agent
28- from strands .models import Model
2931
3032from stickler .comparators .base import BaseComparator
3133
34+ try :
35+ from botocore .exceptions import NoCredentialsError
36+ from strands import Agent
37+ from strands .models import Model
38+
39+ STRANDS_AVAILABLE = True
40+ except ImportError :
41+ STRANDS_AVAILABLE = False
42+
43+ # Create mock classes for when strands is not available
44+ class Model :
45+ pass
46+
47+ class Agent :
48+ pass
49+
50+ class NoCredentialsError (Exception ):
51+ pass
52+
3253
3354class LLMComparator (BaseComparator ):
3455 """Large Language Model-based semantic comparator.
35-
56+
3657 This comparator uses LLMs to perform intelligent semantic comparisons that go
3758 beyond simple string matching. It can understand context, handle abbreviations,
3859 recognize synonyms, and apply domain-specific comparison logic through custom
3960 evaluation guidelines.
40-
61+
4162 The comparator returns binary similarity scores (0.0 or 1.0) based on whether
4263 the LLM determines the values are semantically equivalent. It handles edge cases
4364 like None values and provides detailed comparison information for debugging.
44-
65+
4566 Attributes:
4667 model (Union[Model, str]): The LLM model identifier or Model instance.
4768 eval_guidelines (str, optional): Custom guidelines for comparison logic.
4869 system_prompt (str): The system prompt used to instruct the LLM.
4970 prompt_template (Template): Jinja2 template for formatting comparison prompts.
5071 agent (Agent): The strands Agent instance for LLM interactions.
5172 threshold (float): Inherited from BaseComparator, used for binary decisions.
52-
73+
5374 Note:
5475 This comparator requires AWS Bedrock access and proper authentication.
5576 API calls incur costs and latency, so consider caching for repeated comparisons.
5677 """
78+
5779 def __init__ (
5880 self ,
5981 model : Union [Model , str ] = None ,
6082 eval_guidelines : str = None ,
6183 ):
6284 """Initialize the LLM comparator.
63-
85+
6486 Args:
6587 model: The LLM model to use for comparisons. Can be a model identifier
6688 string (e.g., "us.anthropic.claude-3-haiku-20240307-v1:0") or a
6789 strands Model instance. Defaults to Claude 3 Haiku.
6890 eval_guidelines: Optional custom guidelines to include in the comparison
6991 prompt. These guidelines help the LLM understand domain-specific
7092 comparison rules (e.g., "Consider abbreviations equivalent").
71-
93+
7294 Raises:
73- Exception: If the model cannot be initialized or AWS credentials are invalid.
74-
95+ ImportError: If strands-agents is not installed.
96+ ValueError: If the model parameter is not provided.
97+
7598 Example:
7699 >>> # Basic initialization
77100 >>> comparator = LLMComparator()
78-
101+
79102 >>> # With custom model and guidelines
80103 >>> comparator = LLMComparator(
81104 ... model="us.amazon.nova-lite-v1:0",
82105 ... eval_guidelines="Consider street abbreviations equivalent"
83106 ... )
84107 """
85108 super ().__init__ ()
109+
110+ # Check if strands is available
111+ if not STRANDS_AVAILABLE :
112+ raise ImportError (
113+ "LLMComparator requires the 'strands-agents' package. "
114+ "Install it with: pip install stickler-eval[llm]"
115+ )
116+
86117 if model is None :
87118 raise ValueError ("Model must be provided for LLMComparator." )
88119 self .model = model
@@ -92,25 +123,23 @@ def __init__(
92123 self .eval_guidelines = html .escape (eval_guidelines )
93124 else :
94125 self .eval_guidelines = eval_guidelines
95-
126+
96127 # Initialize Agent
97128 self .agent = Agent (
98- model = self .model ,
99- system_prompt = self .system_prompt ,
100- callback_handler = None
129+ model = self .model , system_prompt = self .system_prompt , callback_handler = None
101130 )
102131
103132 def _default_system_prompt (self ) -> str :
104133 """Generate the default system prompt for the LLM.
105-
134+
106135 Returns:
107136 str: System prompt instructing the LLM to perform binary comparisons.
108137 """
109138 return "You are a helpful assistant that compares two values and determines if they are equivalent. Only return one word: 'true' or 'false'."
110-
139+
111140 def _default_prompt_template (self ) -> Template :
112141 """Generate the default Jinja2 template for comparison prompts.
113-
142+
114143 Returns:
115144 Template: Jinja2 template that formats comparison prompts with values
116145 and optional evaluation guidelines.
@@ -133,44 +162,44 @@ def _default_prompt_template(self) -> Template:
133162
134163 template = Template (prompt_template )
135164 return template
136-
165+
137166 def _invoke_agent (self , prompt : str ) -> str :
138167 """Invoke the LLM agent with a formatted prompt.
139-
168+
140169 Args:
141170 prompt: The formatted prompt string to send to the LLM.
142-
171+
143172 Returns:
144173 str: The text response from the LLM.
145-
174+
146175 Raises:
147176 Exception: If the agent call fails or response format is unexpected.
148177 """
149178 result = self .agent (prompt )
150179 return result .message ["content" ][0 ]["text" ]
151-
180+
152181 def compare (self , value1 : Any , value2 : Any ) -> float :
153182 """Compare two values using LLM-based semantic analysis.
154-
183+
155184 This method converts both values to strings and uses the configured LLM
156185 to determine if they are semantically equivalent. The comparison considers
157186 context, abbreviations, synonyms, and any provided evaluation guidelines.
158-
187+
159188 Args:
160189 value1: First value to compare. Can be any type that converts to string.
161190 value2: Second value to compare. Can be any type that converts to string.
162-
191+
163192 Returns:
164193 float: Binary similarity score:
165194 - 1.0 if the LLM determines the values are equivalent
166195 - 0.0 if the LLM determines the values are not equivalent
167196 - 0.0 if an error occurs during comparison
168-
197+
169198 Note:
170199 - None values: Returns 1.0 if both are None, 0.0 if only one is None
171200 - Error handling: Returns 0.0 for any exceptions during LLM calls
172201 - Cost consideration: Each call incurs API costs and latency
173-
202+
174203 Example:
175204 >>> comparator = LLMComparator()
176205 >>> comparator.compare("St. John's Street", "Saint John's St")
@@ -190,51 +219,50 @@ def compare(self, value1: Any, value2: Any) -> float:
190219 formatted_prompt = self .prompt_template .render (
191220 value1 = html .escape (str (value1 )),
192221 value2 = html .escape (str (value2 )),
193- eval_guidelines = self .eval_guidelines
222+ eval_guidelines = self .eval_guidelines ,
194223 )
195-
224+
196225 try :
197226 # Get LLM response
198227 response = self ._invoke_agent (formatted_prompt )
199228 # Parse response to boolean
200229 response_lower = response .strip ().lower ()
201- if ' true' in response_lower :
230+ if " true" in response_lower :
202231 return 1.0
203232 else :
204233 return 0.0
205-
234+
206235 except NoCredentialsError :
207236 print ("Error: AWS credentials not found." )
208- raise
237+ raise
209238
210239 except Exception as e :
211240 print (f"Error during LLM call: { e } " )
212241 raise
213242
214-
215243 def get_comparison_details (self , value1 : Any , value2 : Any ) -> Dict [str , Any ]:
216244 """Get detailed information about a comparison operation.
217-
245+
218246 This method provides comprehensive details about the comparison process,
219247 including the formatted prompt, LLM response, model information, and
220248 final comparison result. Useful for debugging, auditing, and understanding
221249 how the LLM made its decision.
222-
250+
223251 Args:
224252 value1: First value to compare. Can be any type that converts to string.
225253 value2: Second value to compare. Can be any type that converts to string.
226-
254+
227255 Returns:
228256 Dict[str, Any]: Dictionary containing comparison details:
229257 - 'prompt' (str): The formatted prompt sent to the LLM
230258 - 'llm_response' (str): Raw response from the LLM
231259 - 'model_id' (Union[Model, str]): The model used (string ID or Model instance)
232260 - 'comparison_result' (float): Final similarity score (0.0 or 1.0)
233-
261+
234262 On error:
235263 - 'error' (str): Error message describing what went wrong
236264 - 'comparison_result' (bool): False to indicate failure
237-
265+
238266 Example:
239267 >>> comparator = LLMComparator(eval_guidelines="Consider abbreviations")
240268 >>> details = comparator.get_comparison_details("St. John", "Saint John")
@@ -248,19 +276,16 @@ def get_comparison_details(self, value1: Any, value2: Any) -> Dict[str, Any]:
248276 formatted_prompt = self .prompt_template .render (
249277 value1 = html .escape (str (value1 )),
250278 value2 = html .escape (str (value2 )),
251- eval_guidelines = self .eval_guidelines
279+ eval_guidelines = self .eval_guidelines ,
252280 )
253-
281+
254282 try :
255283 response = self ._invoke_agent (formatted_prompt )
256284 return {
257285 "prompt" : formatted_prompt ,
258286 "llm_response" : response ,
259287 "model_id" : self .model ,
260- "comparison_result" : self .compare (value1 , value2 )
288+ "comparison_result" : self .compare (value1 , value2 ),
261289 }
262290 except Exception as e :
263- return {
264- "error" : str (e ),
265- "comparison_result" : False
266- }
291+ return {"error" : str (e ), "comparison_result" : False }
0 commit comments