Skip to content

Commit a4faa5b

Browse files
authored
Strands and Jinja2 package updates (awslabs#60)
* Added a mock class for strands to mitigate import errors, and raise an import error for trying to use LLM * Updated jinja template package version
1 parent 420ddf5 commit a4faa5b

3 files changed

Lines changed: 88 additions & 54 deletions

File tree

pyproject.toml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ dependencies = [
1919
"scipy>=1.10.0",
2020
"psutil>=5.8.0",
2121
"pandas>=1.5.0",
22-
"jsonschema>=4.0.0"
22+
"jsonschema>=4.0.0",
23+
"jinja2>=3.1.6,<3.2.0"
2324
]
2425

2526
[project.optional-dependencies]
@@ -30,10 +31,8 @@ dev = [
3031
"beautifulsoup4>=4.14.2",
3132
"ruff>=0.14.10",
3233
]
33-
3434
llm = [
35-
"strands-agents>=1.0.0,<=1.16.0",
36-
"jinja2>=3.0.0,<=3.1.6"
35+
"strands-agents>=1.0.0,<=1.16.0"
3736
]
3837

3938

src/stickler/comparators/__init__.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@
1414
from stickler.comparators.structured import StructuredModelComparator
1515
from stickler.comparators.utils import generate_bedrock_embedding
1616

17+
# Import LLMComparator if strands-agents is available
18+
try:
19+
from stickler.comparators.llm import LLMComparator # noqa: F401
20+
21+
LLM_AVAILABLE = True
22+
except ImportError:
23+
LLM_AVAILABLE = False
24+
1725
# Import BERTComparator if evaluate is available
1826
try:
1927
from stickler.comparators.bert import BERTComparator # noqa: F401
@@ -39,12 +47,15 @@
3947
"NumericComparator",
4048
"NumericExactC",
4149
"ExactComparator",
42-
"LLMComparator",
4350
"StructuredModelComparator",
4451
"SemanticComparator",
4552
"generate_bedrock_embedding",
4653
]
4754

55+
# Add LLMComparator to __all__ if available
56+
if LLM_AVAILABLE:
57+
__all__.append("LLMComparator")
58+
4859
# Add BERTComparator to __all__ if available
4960
if BERT_AVAILABLE:
5061
__all__.append("BERTComparator")
@@ -53,4 +64,3 @@
5364
if RAPIDFUZZ_AVAILABLE:
5465
__all__.append("FuzzyComparator")
5566
__all__.append("Fuzz")
56-

src/stickler/comparators/llm.py

Lines changed: 73 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -9,80 +9,111 @@
99
library and supports customizable evaluation guidelines for domain-specific
1010
comparison logic.
1111
12+
Note:
13+
This comparator requires the optional 'llm' dependencies. Install with:
14+
pip install stickler-eval[llm]
15+
1216
Example:
1317
Integration with StructuredModel:
1418
>>> from stickler.structured_object_evaluator.models.comparable_field import ComparableField
15-
>>>
19+
>>>
1620
>>> class Address(StructuredModel):
1721
... street: str = ComparableField(
1822
... comparator=LLMComparator(eval_guidelines="Consider street abbreviations"),
1923
... threshold=0.8
2024
... )
2125
"""
26+
2227
import html
2328
from typing import Any, Dict, Union
2429

25-
from botocore.exceptions import NoCredentialsError
2630
from jinja2 import Template
27-
from strands import Agent
28-
from strands.models import Model
2931

3032
from stickler.comparators.base import BaseComparator
3133

34+
try:
35+
from botocore.exceptions import NoCredentialsError
36+
from strands import Agent
37+
from strands.models import Model
38+
39+
STRANDS_AVAILABLE = True
40+
except ImportError:
41+
STRANDS_AVAILABLE = False
42+
43+
# Create mock classes for when strands is not available
44+
class Model:
45+
pass
46+
47+
class Agent:
48+
pass
49+
50+
class NoCredentialsError(Exception):
51+
pass
52+
3253

3354
class LLMComparator(BaseComparator):
3455
"""Large Language Model-based semantic comparator.
35-
56+
3657
This comparator uses LLMs to perform intelligent semantic comparisons that go
3758
beyond simple string matching. It can understand context, handle abbreviations,
3859
recognize synonyms, and apply domain-specific comparison logic through custom
3960
evaluation guidelines.
40-
61+
4162
The comparator returns binary similarity scores (0.0 or 1.0) based on whether
4263
the LLM determines the values are semantically equivalent. It handles edge cases
4364
like None values and provides detailed comparison information for debugging.
44-
65+
4566
Attributes:
4667
model (Union[Model, str]): The LLM model identifier or Model instance.
4768
eval_guidelines (str, optional): Custom guidelines for comparison logic.
4869
system_prompt (str): The system prompt used to instruct the LLM.
4970
prompt_template (Template): Jinja2 template for formatting comparison prompts.
5071
agent (Agent): The strands Agent instance for LLM interactions.
5172
threshold (float): Inherited from BaseComparator, used for binary decisions.
52-
73+
5374
Note:
5475
This comparator requires AWS Bedrock access and proper authentication.
5576
API calls incur costs and latency, so consider caching for repeated comparisons.
5677
"""
78+
5779
def __init__(
5880
self,
5981
model: Union[Model, str] = None,
6082
eval_guidelines: str = None,
6183
):
6284
"""Initialize the LLM comparator.
63-
85+
6486
Args:
6587
model: The LLM model to use for comparisons. Can be a model identifier
6688
string (e.g., "us.anthropic.claude-3-haiku-20240307-v1:0") or a
6789
strands Model instance. Defaults to Claude 3 Haiku.
6890
eval_guidelines: Optional custom guidelines to include in the comparison
6991
prompt. These guidelines help the LLM understand domain-specific
7092
comparison rules (e.g., "Consider abbreviations equivalent").
71-
93+
7294
Raises:
73-
Exception: If the model cannot be initialized or AWS credentials are invalid.
74-
95+
ImportError: If strands-agents is not installed.
96+
ValueError: If the model parameter is not provided.
97+
7598
Example:
7699
>>> # Basic initialization
77100
>>> comparator = LLMComparator()
78-
101+
79102
>>> # With custom model and guidelines
80103
>>> comparator = LLMComparator(
81104
... model="us.amazon.nova-lite-v1:0",
82105
... eval_guidelines="Consider street abbreviations equivalent"
83106
... )
84107
"""
85108
super().__init__()
109+
110+
# Check if strands is available
111+
if not STRANDS_AVAILABLE:
112+
raise ImportError(
113+
"LLMComparator requires the 'strands-agents' package. "
114+
"Install it with: pip install stickler-eval[llm]"
115+
)
116+
86117
if model is None:
87118
raise ValueError("Model must be provided for LLMComparator.")
88119
self.model = model
@@ -92,25 +123,23 @@ def __init__(
92123
self.eval_guidelines = html.escape(eval_guidelines)
93124
else:
94125
self.eval_guidelines = eval_guidelines
95-
126+
96127
# Initialize Agent
97128
self.agent = Agent(
98-
model=self.model,
99-
system_prompt=self.system_prompt,
100-
callback_handler=None
129+
model=self.model, system_prompt=self.system_prompt, callback_handler=None
101130
)
102131

103132
def _default_system_prompt(self) -> str:
104133
"""Generate the default system prompt for the LLM.
105-
134+
106135
Returns:
107136
str: System prompt instructing the LLM to perform binary comparisons.
108137
"""
109138
return "You are a helpful assistant that compares two values and determines if they are equivalent. Only return one word: 'true' or 'false'."
110-
139+
111140
def _default_prompt_template(self) -> Template:
112141
"""Generate the default Jinja2 template for comparison prompts.
113-
142+
114143
Returns:
115144
Template: Jinja2 template that formats comparison prompts with values
116145
and optional evaluation guidelines.
@@ -133,44 +162,44 @@ def _default_prompt_template(self) -> Template:
133162

134163
template = Template(prompt_template)
135164
return template
136-
165+
137166
def _invoke_agent(self, prompt: str) -> str:
138167
"""Invoke the LLM agent with a formatted prompt.
139-
168+
140169
Args:
141170
prompt: The formatted prompt string to send to the LLM.
142-
171+
143172
Returns:
144173
str: The text response from the LLM.
145-
174+
146175
Raises:
147176
Exception: If the agent call fails or response format is unexpected.
148177
"""
149178
result = self.agent(prompt)
150179
return result.message["content"][0]["text"]
151-
180+
152181
def compare(self, value1: Any, value2: Any) -> float:
153182
"""Compare two values using LLM-based semantic analysis.
154-
183+
155184
This method converts both values to strings and uses the configured LLM
156185
to determine if they are semantically equivalent. The comparison considers
157186
context, abbreviations, synonyms, and any provided evaluation guidelines.
158-
187+
159188
Args:
160189
value1: First value to compare. Can be any type that converts to string.
161190
value2: Second value to compare. Can be any type that converts to string.
162-
191+
163192
Returns:
164193
float: Binary similarity score:
165194
- 1.0 if the LLM determines the values are equivalent
166195
- 0.0 if the LLM determines the values are not equivalent
167196
- 0.0 if an error occurs during comparison
168-
197+
169198
Note:
170199
- None values: Returns 1.0 if both are None, 0.0 if only one is None
171200
- Error handling: Returns 0.0 for any exceptions during LLM calls
172201
- Cost consideration: Each call incurs API costs and latency
173-
202+
174203
Example:
175204
>>> comparator = LLMComparator()
176205
>>> comparator.compare("St. John's Street", "Saint John's St")
@@ -190,51 +219,50 @@ def compare(self, value1: Any, value2: Any) -> float:
190219
formatted_prompt = self.prompt_template.render(
191220
value1=html.escape(str(value1)),
192221
value2=html.escape(str(value2)),
193-
eval_guidelines=self.eval_guidelines
222+
eval_guidelines=self.eval_guidelines,
194223
)
195-
224+
196225
try:
197226
# Get LLM response
198227
response = self._invoke_agent(formatted_prompt)
199228
# Parse response to boolean
200229
response_lower = response.strip().lower()
201-
if 'true' in response_lower:
230+
if "true" in response_lower:
202231
return 1.0
203232
else:
204233
return 0.0
205-
234+
206235
except NoCredentialsError:
207236
print("Error: AWS credentials not found.")
208-
raise
237+
raise
209238

210239
except Exception as e:
211240
print(f"Error during LLM call: {e}")
212241
raise
213242

214-
215243
def get_comparison_details(self, value1: Any, value2: Any) -> Dict[str, Any]:
216244
"""Get detailed information about a comparison operation.
217-
245+
218246
This method provides comprehensive details about the comparison process,
219247
including the formatted prompt, LLM response, model information, and
220248
final comparison result. Useful for debugging, auditing, and understanding
221249
how the LLM made its decision.
222-
250+
223251
Args:
224252
value1: First value to compare. Can be any type that converts to string.
225253
value2: Second value to compare. Can be any type that converts to string.
226-
254+
227255
Returns:
228256
Dict[str, Any]: Dictionary containing comparison details:
229257
- 'prompt' (str): The formatted prompt sent to the LLM
230258
- 'llm_response' (str): Raw response from the LLM
231259
- 'model_id' (Union[Model, str]): The model used (string ID or Model instance)
232260
- 'comparison_result' (float): Final similarity score (0.0 or 1.0)
233-
261+
234262
On error:
235263
- 'error' (str): Error message describing what went wrong
236264
- 'comparison_result' (bool): False to indicate failure
237-
265+
238266
Example:
239267
>>> comparator = LLMComparator(eval_guidelines="Consider abbreviations")
240268
>>> details = comparator.get_comparison_details("St. John", "Saint John")
@@ -248,19 +276,16 @@ def get_comparison_details(self, value1: Any, value2: Any) -> Dict[str, Any]:
248276
formatted_prompt = self.prompt_template.render(
249277
value1=html.escape(str(value1)),
250278
value2=html.escape(str(value2)),
251-
eval_guidelines=self.eval_guidelines
279+
eval_guidelines=self.eval_guidelines,
252280
)
253-
281+
254282
try:
255283
response = self._invoke_agent(formatted_prompt)
256284
return {
257285
"prompt": formatted_prompt,
258286
"llm_response": response,
259287
"model_id": self.model,
260-
"comparison_result": self.compare(value1, value2)
288+
"comparison_result": self.compare(value1, value2),
261289
}
262290
except Exception as e:
263-
return {
264-
"error": str(e),
265-
"comparison_result": False
266-
}
291+
return {"error": str(e), "comparison_result": False}

0 commit comments

Comments
 (0)