forked from i-am-bee/beeai-framework
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsearx_agent.py
More file actions
96 lines (76 loc) · 3.01 KB
/
searx_agent.py
File metadata and controls
96 lines (76 loc) · 3.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import asyncio
import sys
import traceback
from langchain_community.utilities import SearxSearchWrapper
from pydantic import BaseModel, Field
from beeai_framework.adapters.ollama import OllamaChatModel
from beeai_framework.backend import ChatModelOutput, ChatModelStructureOutput, UserMessage
from beeai_framework.errors import FrameworkError
from beeai_framework.template import PromptTemplate, PromptTemplateInput
from beeai_framework.workflows import Workflow
async def main() -> None:
    """Run a two-step web-search agent workflow and print its answer.

    Step 1 (``web_search``) asks the chat model to turn the user's input into
    a search query and runs that query against the Searx instance configured
    below. Step 2 (``generate_output``) asks the model to answer the original
    input using the retrieved search results. The final input/answer pair is
    printed to stdout.
    """
    llm = OllamaChatModel("granite3.1-dense:8b")
    search = SearxSearchWrapper(searx_host="http://127.0.0.1:8888")

    class State(BaseModel):
        # Workflow state handed to every step; steps fill in the optional fields.
        input: str
        search_results: str | None = None
        output: str | None = None

    class InputSchema(BaseModel):
        # Variables rendered into the query-generation prompt.
        input: str

    class WebSearchQuery(BaseModel):
        # Structured output requested from the model in the web_search step.
        search_query: str = Field(description="Search query.")

    class RAGSchema(InputSchema):
        # Variables rendered into the answer prompt; `input` is inherited.
        search_results: str

    # Prompt templates are spelled out line by line so the text sent to the
    # model does not depend on how this source file is indented.
    query_template = (
        "\n"
        "Please create a web search query for the following input.\n"
        "Query: {{input}}"
    )
    answer_template = (
        "\n"
        "Use the following search results to answer the query accurately. "
        "If the results are irrelevant or insufficient, say 'I don't know.'\n"
        "Search Results:\n"
        "{{search_results}}\n"
        "Query: {{input}}\n"
    )

    async def web_search(state: State) -> str:
        """Generate a search query for ``state.input`` and store the results."""
        print("Step: ", web_search.__name__)
        prompt = PromptTemplate(
            PromptTemplateInput(schema=InputSchema, template=query_template)
        ).render(InputSchema(input=state.input))
        output: ChatModelStructureOutput = await llm.create_structure(
            schema=WebSearchQuery, messages=[UserMessage(prompt)]
        )
        # TODO Why is object not of type schema T?
        # The structured result is indexed like a mapping rather than read as
        # a WebSearchQuery attribute.
        # The generated query is passed through unchanged: the previous
        # hard-coded "current weather in ..." prefix skewed every search
        # towards weather results regardless of the question asked.
        state.search_results = search.run(output.object["search_query"])
        return Workflow.NEXT

    async def generate_output(state: State) -> str:
        """Answer ``state.input`` from the search results and end the workflow."""
        print("Step: ", generate_output.__name__)
        prompt = PromptTemplate(
            PromptTemplateInput(schema=RAGSchema, template=answer_template)
        ).render(
            RAGSchema(
                input=state.input,
                # Fall back to a placeholder so the prompt still renders when
                # the search step stored nothing (None or an empty string).
                search_results=state.search_results or "No results available.",
            )
        )
        output: ChatModelOutput = await llm.create(messages=[UserMessage(prompt)])
        state.output = output.get_text_content()
        return Workflow.END

    # Define the structure of the workflow graph
    workflow = Workflow(State)
    workflow.add_step("web_search", web_search)
    workflow.add_step("generate_output", generate_output)

    # Execute the workflow
    result = await workflow.run(State(input="What is the demon core?"))

    print("\n*********************")
    print("Input: ", result.state.input)
    print("Agent: ", result.state.output)
if __name__ == "__main__":
    # Script entry point: run the async workflow to completion.
    try:
        asyncio.run(main())
    except FrameworkError as framework_error:
        # Print the full stack trace first, then hand the framework's own
        # explanation of the failure to sys.exit as the exit argument.
        traceback.print_exc()
        explanation = framework_error.explain()
        sys.exit(explanation)