|
1 | 1 | from typing import Annotated, TypedDict |
2 | 2 |
|
3 | | -from langchain_core.documents import Document |
| 3 | +from langchain_core.embeddings import Embeddings |
| 4 | +from langchain_core.language_models.chat_models import BaseChatModel |
4 | 5 | from langchain_core.messages import BaseMessage |
| 6 | +from langchain_core.runnables import Runnable, RunnableConfig |
5 | 7 | from langgraph.graph.message import add_messages |
6 | 8 |
|
| 9 | +from agent.tasks.rephrase import create_rephrase_chain |
7 | 10 | from tools.external_search.state import WebSearchResult |
8 | 11 |
|
9 | 12 |
|
10 | | -class AdditionalContent(TypedDict): |
| 13 | +class AdditionalContent(TypedDict, total=False): |
11 | 14 | search_results: list[WebSearchResult] |
12 | 15 |
|
13 | 16 |
|
14 | | -class BaseState(TypedDict): |
15 | | - # (Everything the Chainlit layer uses should be included here) |
16 | | - |
| 17 | +class InputState(TypedDict, total=False): |
17 | 18 | user_input: str # User input text |
18 | | - chat_history: Annotated[list[BaseMessage], add_messages] |
19 | | - context: list[Document] |
| 19 | + |
| 20 | + |
| 21 | +class OutputState(TypedDict, total=False): |
20 | 22 | answer: str # primary LLM response that is streamed to the user |
21 | 23 | additional_content: AdditionalContent # sends on graph completion |
22 | 24 |
|
23 | 25 |
|
| 26 | +class BaseState(InputState, OutputState, total=False): |
| 27 | + rephrased_input: str # LLM-generated query from user input |
| 28 | + chat_history: Annotated[list[BaseMessage], add_messages] |
| 29 | + |
| 30 | + |
24 | 31 | class BaseGraphBuilder: |
25 | | - pass # NOTE: Anything that is common to all graph builders goes here |
| 32 | + # NOTE: Anything that is common to all graph builders goes here |
| 33 | + |
| 34 | + def __init__( |
| 35 | + self, |
| 36 | + llm: BaseChatModel, |
| 37 | + embedding: Embeddings, |
| 38 | + ) -> None: |
| 39 | + self.rephrase_chain: Runnable = create_rephrase_chain(llm) |
| 40 | + |
| 41 | + async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseState: |
| 42 | + rephrased_input: str = await self.rephrase_chain.ainvoke( |
| 43 | + { |
| 44 | + "user_input": state["user_input"], |
| 45 | + "chat_history": state["chat_history"], |
| 46 | + }, |
| 47 | + config, |
| 48 | + ) |
| 49 | + return BaseState(rephrased_input=rephrased_input) |
0 commit comments