Skip to content

Commit f413b2b

Browse files
authored
FIRE-572 FIRE-589 | Add support for tsq detector in python sdk (#262)
* Add support for tsq detector in python sdk * Update pydoc * Update pydoc and fix ToolDefinition name * CR * Bump version 0.8.0 -> 0.9.0 * Improve test coverage * Fix tests * Fix tracing base_url
1 parent d2dc48d commit f413b2b

6 files changed

Lines changed: 244 additions & 23 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api"
55

66
[tool.poetry]
77
name = "qualifire"
8-
version = "0.8.0"
8+
version = "0.9.0"
99
description = "Qualifire Python SDK"
1010
readme = "README.md"
1111
authors = ["qualifire-dev <dror@qualifire.ai>"]

qualifire/client.py

Lines changed: 77 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,13 @@
66

77
import requests
88

9-
from .types import EvaluationRequest, EvaluationResponse, LLMMessage, SyntaxCheckArgs
9+
from .types import (
10+
EvaluationRequest,
11+
EvaluationResponse,
12+
LLMMessage,
13+
LLMToolDefinition,
14+
SyntaxCheckArgs,
15+
)
1016
from .utils import get_api_key, get_base_url
1117

1218
logger = logging.getLogger("qualifire")
@@ -27,26 +33,32 @@ def __init__(
2733

2834
def evaluate(
2935
self,
30-
input: str,
31-
output: str,
36+
input: Optional[str] = None,
37+
output: Optional[str] = None,
38+
messages: Optional[List[LLMMessage]] = None,
39+
available_tools: Optional[List[LLMToolDefinition]] = None,
3240
assertions: Optional[List[str]] = None,
3341
dangerous_content_check: bool = False,
3442
grounding_check: bool = False,
3543
hallucinations_check: bool = False,
3644
harassment_check: bool = False,
3745
hate_speech_check: bool = False,
3846
instructions_following_check: bool = False,
39-
messages: Optional[List[LLMMessage]] = None,
4047
pii_check: bool = False,
4148
prompt_injections: bool = False,
4249
sexual_content_check: bool = False,
4350
syntax_checks: Optional[Dict[str, SyntaxCheckArgs]] = None,
51+
tool_selection_quality_check: bool = False,
4452
) -> Union[EvaluationResponse, None]:
4553
"""
4654
Evaluates the given input and output pairs.
4755
4856
:param input: The primary input for the evaluation.
4957
:param output: The primary output (e.g., LLM response) to evaluate.
58+
:param messages: List of message objects representing conversation history.
59+
Must be set if tool_selection_quality_check is True.
60+
:param available_tools: List of available tools.
61+
Must be set if tool_selection_quality_check is True.
5062
:param assertions: A list of custom assertions to check against the output.
5163
:param dangerous_content_check: Check for dangerous content generation.
5264
:param grounding_check: Check if the output is grounded in the provided
@@ -56,11 +68,12 @@ def evaluate(
5668
:param hate_speech_check: Check for hate speech.
5769
:param instructions_following_check: Check if the output follows instructions
5870
in the input/messages.
59-
:param messages: List of message objects representing conversation history.
6071
:param pii_check: Check for personally identifiable information.
6172
:param prompt_injections: Check for attempts at prompt injection.
6273
:param sexual_content_check: Check for sexually explicit content.
6374
:param syntax_checks: Dictionary defining syntax checks (e.g., JSON, SQL).
75+
:param tool_selection_quality_check: Check for tool selection quality.
76+
Only works when `available_tools` and `messages` are provided.
6477
6578
:return: An EvaluationResponse object containing the evaluation results.
6679
:raises Exception: If an error occurs during the evaluation.
@@ -95,27 +108,83 @@ def evaluate(
95108
sexual_content_check=True,
96109
syntax_checks={
97110
"json": SyntaxCheckArgs(args="strict") # Example syntax check
98-
}
111+
},
99112
)
100113
```
101-
"""
102114
115+
Example with tools:
116+
```python
117+
from qualifire import Qualifire
118+
from qualifire.types import LLMMessage, LLMToolCall, LLMToolDefinition
119+
120+
qualifire = Qualifire(api_key="your_api_key")
121+
122+
evaluation_response = qualifire.evaluate(
123+
messages=[
124+
LLMMessage(
125+
role="user",
126+
content="What is the weather tomorrow in New York?",
127+
),
128+
LLMMessage(
129+
role="assistant",
130+
content="please run the following tool",
131+
tool_calls=[
132+
LLMToolCall(
133+
id="tool_call_id",
134+
name="get_weather_forecast",
135+
arguments={
136+
"location": "New York, NY",
137+
"date": "tomorrow",
138+
},
139+
),
140+
],
141+
),
142+
],
143+
available_tools=[
144+
LLMToolDefinition(
145+
name="get_weather_forecast",
146+
description="Provides the weather forecast for a given location and date.",
147+
parameters={
148+
"type": "object",
149+
"properties": {
150+
"location": {
151+
"type": "string",
152+
"description": "The city and state, e.g., San Francisco, CA",
153+
},
154+
"date": {
155+
"type": "string",
156+
"description": "The date for the forecast, e.g., tomorrow, or YYYY-MM-DD",
157+
},
158+
},
159+
"required": [
160+
"location",
161+
"date",
162+
],
163+
},
164+
),
165+
],
166+
tool_selection_quality_check=True,
167+
)
168+
```
169+
""" # noqa E501
103170
url = f"{self._base_url}/api/evaluation/evaluate"
104171
request = EvaluationRequest(
105172
input=input,
106173
output=output,
174+
messages=messages,
175+
available_tools=available_tools,
107176
assertions=assertions,
108177
dangerous_content_check=dangerous_content_check,
109178
grounding_check=grounding_check,
110179
hallucinations_check=hallucinations_check,
111180
harassment_check=harassment_check,
112181
hate_speech_check=hate_speech_check,
113182
instructions_following_check=instructions_following_check,
114-
messages=messages,
115183
pii_check=pii_check,
116184
prompt_injections=prompt_injections,
117185
sexual_content_check=sexual_content_check,
118186
syntax_checks=syntax_checks,
187+
tool_selection_quality_check=tool_selection_quality_check,
119188
)
120189

121190
# Filter out None values before dumping to JSON

qualifire/tracer_init.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __configure_tracer(api_key: str) -> None:
3737
__suppress_prints(
3838
Traceloop.init,
3939
app_name="qualifire-agent",
40-
api_endpoint=f"{get_base_url()}/telemetry", # /v1/traces is automatically added # noqa: E501
40+
api_endpoint=f"{get_base_url()}/api/telemetry", # /v1/traces is automatically added # noqa: E501
4141
headers={"X-Qualifire-API-Key": api_key},
4242
telemetry_enabled=False,
4343
traceloop_sync_enabled=False,

qualifire/types.py

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,27 @@
1-
from typing import Dict, List, Optional
1+
from typing import Any, Dict, List, Optional
22

33
from dataclasses import dataclass, field
44

55

6+
@dataclass
7+
class LLMToolDefinition:
8+
name: str
9+
description: str
10+
parameters: Dict[str, Any]
11+
12+
13+
@dataclass
14+
class LLMToolCall:
15+
name: str
16+
arguments: Dict[str, Any]
17+
id: Optional[str]
18+
19+
620
@dataclass
721
class LLMMessage:
8-
content: str
922
role: str
23+
content: str
24+
tool_calls: Optional[List[LLMToolCall]] = None
1025

1126

1227
@dataclass
@@ -16,20 +31,44 @@ class SyntaxCheckArgs:
1631

1732
@dataclass
1833
class EvaluationRequest:
19-
input: str
20-
output: str
21-
dangerous_content_check: bool
22-
hallucinations_check: bool
23-
harassment_check: bool
24-
hate_speech_check: bool
25-
pii_check: bool
26-
prompt_injections: bool
27-
sexual_content_check: bool
34+
input: Optional[str] = None
35+
output: Optional[str] = None
36+
messages: Optional[List[LLMMessage]] = field(default_factory=list)
37+
available_tools: Optional[List[LLMToolDefinition]] = None
38+
dangerous_content_check: bool = False
39+
hallucinations_check: bool = False
40+
harassment_check: bool = False
41+
hate_speech_check: bool = False
42+
pii_check: bool = False
43+
prompt_injections: bool = False
44+
sexual_content_check: bool = False
2845
grounding_check: bool = False
2946
instructions_following_check: bool = False
3047
syntax_checks: Optional[Dict[str, SyntaxCheckArgs]] = None
31-
messages: Optional[List[LLMMessage]] = field(default_factory=list)
3248
assertions: Optional[List[str]] = field(default_factory=list)
49+
tool_selection_quality_check: bool = False
50+
51+
def __post_init__(self):
52+
self._validate_messages_input_output()
53+
self._validate_tsq_requirements()
54+
55+
def _validate_messages_input_output(self):
56+
if not self.messages and not self.input and not self.output:
57+
raise ValueError(
58+
"At least one of messages, input, or output must be set",
59+
)
60+
61+
def _validate_tsq_requirements(self):
62+
if self.tool_selection_quality_check and not self.messages:
63+
raise ValueError(
64+
"messages must be provided in conjunction "
65+
"with tool_selection_quality_check=True."
66+
)
67+
if self.tool_selection_quality_check and not self.available_tools:
68+
raise ValueError(
69+
"available_tools must be provided in conjunction "
70+
"with tool_selection_quality_check=True."
71+
)
3372

3473

3574
@dataclass

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
if __name__ == "__main__":
55
setup(
66
name="qualifire",
7-
version="0.8.0",
7+
version="0.9.0",
88
description="Qualifire Python SDK",
99
author="qualifire-dev",
1010
author_email="dror@qualifire.ai",

tests/test_types.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import contextlib
2+
3+
import pytest
4+
5+
from qualifire.types import EvaluationRequest, LLMMessage, LLMToolDefinition
6+
7+
_test_llm_messages = [
8+
LLMMessage(
9+
role="user",
10+
content="test",
11+
),
12+
]
13+
14+
_test_available_tools = [
15+
LLMToolDefinition(
16+
name="foo",
17+
description="foo tool function definition",
18+
parameters={
19+
"type": "object",
20+
"properties": {
21+
"bar": {
22+
"type": "string",
23+
},
24+
"baz": {
25+
"type": "integer",
26+
},
27+
},
28+
"required": ["bar", "baz"],
29+
},
30+
)
31+
]
32+
33+
34+
class TestEvaluationRequest:
35+
@pytest.mark.parametrize(
36+
"messages,input_,output,expected_error",
37+
[
38+
(None, None, None, True),
39+
([], None, None, True),
40+
(None, "", None, True),
41+
(None, None, "", True),
42+
(_test_llm_messages, None, None, False),
43+
(_test_llm_messages, "", None, False),
44+
(_test_llm_messages, None, "", False),
45+
(_test_llm_messages, "", "", False),
46+
(None, "input", None, False),
47+
(None, "input", "", False),
48+
([], "input", None, False),
49+
([], "input", "", False),
50+
(None, None, "output", False),
51+
(None, "", "output", False),
52+
([], None, "output", False),
53+
([], "", "output", False),
54+
(_test_llm_messages, "input", None, False),
55+
(_test_llm_messages, "input", "", False),
56+
(_test_llm_messages, None, "output", False),
57+
(_test_llm_messages, "", "output", False),
58+
(None, "input", "output", False),
59+
([], "input", "output", False),
60+
(_test_llm_messages, "input", "output", False),
61+
],
62+
)
63+
def test_validate_messages_input_output(
64+
self,
65+
messages,
66+
input_,
67+
output,
68+
expected_error,
69+
):
70+
with pytest.raises(ValueError) if expected_error else contextlib.nullcontext():
71+
EvaluationRequest(
72+
messages=messages,
73+
input=input_,
74+
output=output,
75+
)
76+
77+
@pytest.mark.parametrize(
78+
"tsq_check,messages,available_tools,expected_error",
79+
[
80+
(True, None, None, True),
81+
(True, [], None, True),
82+
(True, None, [], True),
83+
(True, [], [], True),
84+
(True, _test_llm_messages, None, True),
85+
(True, _test_llm_messages, [], True),
86+
(True, None, _test_available_tools, True),
87+
(True, [], _test_available_tools, True),
88+
(True, _test_llm_messages, _test_available_tools, False),
89+
(False, None, None, False),
90+
(False, [], None, False),
91+
(False, None, [], False),
92+
(False, [], [], False),
93+
(False, _test_llm_messages, None, False),
94+
(False, _test_llm_messages, [], False),
95+
(False, None, _test_available_tools, False),
96+
(False, [], _test_available_tools, False),
97+
(False, _test_llm_messages, _test_available_tools, False),
98+
],
99+
)
100+
def test_validate_tsq_requirements(
101+
self,
102+
tsq_check,
103+
messages,
104+
available_tools,
105+
expected_error,
106+
):
107+
with pytest.raises(ValueError) if expected_error else contextlib.nullcontext():
108+
EvaluationRequest(
109+
input="input", # To pass the messages-input-output validation
110+
messages=messages,
111+
available_tools=available_tools,
112+
tool_selection_quality_check=tsq_check,
113+
)

0 commit comments

Comments
 (0)