Skip to content

Commit ba01931

Browse files
committed
code quality fixes
1 parent f35f3e0 commit ba01931

3 files changed

Lines changed: 22 additions & 61 deletions

File tree

src/agent/profiles/base.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -52,20 +52,12 @@ class BaseState(InputState, OutputState, total=False):
5252
class BaseGraphBuilder:
5353
"""Base class for all graph builders with common preprocessing and postprocessing."""
5454

55-
def __init__(
56-
self,
57-
llm: BaseChatModel,
58-
embedding: Embeddings
59-
) -> None:
55+
def __init__(self, llm: BaseChatModel, embedding: Embeddings) -> None:
6056
"""Initialize with LLM and embedding models."""
6157
self.preprocessing_workflow: Runnable = create_preprocessing_workflow(llm)
6258
self.search_workflow: Runnable = create_search_workflow(llm)
6359

64-
async def preprocess(
65-
self,
66-
state: BaseState,
67-
config: RunnableConfig
68-
) -> BaseState:
60+
async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
6961
"""Run the complete preprocessing workflow and map results to state."""
7062
result: PreprocessingState = await self.preprocessing_workflow.ainvoke(
7163
PreprocessingState(
@@ -77,10 +69,7 @@ async def preprocess(
7769

7870
return self._map_preprocessing_result(result)
7971

80-
def _map_preprocessing_result(
81-
self,
82-
result: PreprocessingState
83-
) -> BaseState:
72+
def _map_preprocessing_result(self, result: PreprocessingState) -> BaseState:
8473
"""Map preprocessing results to BaseState with defaults."""
8574
return BaseState(
8675
rephrased_input=result["rephrased_input"],
@@ -90,11 +79,7 @@ def _map_preprocessing_result(
9079
detected_language=result.get("detected_language", DEFAULT_LANGUAGE),
9180
)
9281

93-
async def postprocess(
94-
self,
95-
state: BaseState,
96-
config: RunnableConfig
97-
) -> BaseState:
82+
async def postprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
9883
"""Postprocess that preserves existing state and conditionally adds search results."""
9984
search_results: list[WebSearchResult] = []
10085

@@ -113,4 +98,9 @@ async def postprocess(
11398
search_results = result["search_results"]
11499

115100
# Create new state with updated additional_content
116-
return BaseState(**{**state, "additional_content": AdditionalContent(search_results=search_results)})
101+
return BaseState(
102+
**{
103+
**state,
104+
"additional_content": AdditionalContent(search_results=search_results),
105+
}
106+
)

src/agent/profiles/react_to_me.py

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,7 @@ class ReactToMeState(BaseState):
2323
class ReactToMeGraphBuilder(BaseGraphBuilder):
2424
"""Graph builder for ReactToMe profile with Reactome-specific functionality."""
2525

26-
def __init__(
27-
self,
28-
llm: BaseChatModel,
29-
embedding: Embeddings
30-
) -> None:
26+
def __init__(self, llm: BaseChatModel, embedding: Embeddings) -> None:
3127
"""Initialize ReactToMe graph builder with required components."""
3228
super().__init__(llm, embedding)
3329

@@ -69,25 +65,20 @@ def _build_workflow(self) -> StateGraph:
6965
return state_graph
7066

7167
async def preprocess(
72-
self,
73-
state: ReactToMeState,
74-
config: RunnableConfig
68+
self, state: ReactToMeState, config: RunnableConfig
7569
) -> ReactToMeState:
7670
"""Run preprocessing workflow."""
7771
result = await super().preprocess(state, config)
7872
return ReactToMeState(**result)
7973

8074
async def proceed_with_research(
81-
self,
82-
state: ReactToMeState
75+
self, state: ReactToMeState
8376
) -> Literal["Continue", "Finish"]:
8477
"""Determine whether to proceed with research based on safety check."""
8578
return "Continue" if state["safety"] == SAFETY_SAFE else "Finish"
8679

8780
async def generate_unsafe_response(
88-
self,
89-
state: ReactToMeState,
90-
config: RunnableConfig
81+
self, state: ReactToMeState, config: RunnableConfig
9182
) -> ReactToMeState:
9283
"""Generate appropriate refusal response for unsafe queries."""
9384
final_answer_message = await self.unsafe_answer_generator.ainvoke(
@@ -120,9 +111,7 @@ async def generate_unsafe_response(
120111
)
121112

122113
async def call_model(
123-
self,
124-
state: ReactToMeState,
125-
config: RunnableConfig
114+
self, state: ReactToMeState, config: RunnableConfig
126115
) -> ReactToMeState:
127116
"""Generate response using Reactome RAG for safe queries."""
128117
result: dict[str, Any] = await self.reactome_rag.ainvoke(
@@ -147,9 +136,6 @@ async def call_model(
147136
)
148137

149138

150-
def create_reactome_graph(
151-
llm: BaseChatModel,
152-
embedding: Embeddings
153-
) -> StateGraph:
139+
def create_reactome_graph(llm: BaseChatModel, embedding: Embeddings) -> StateGraph:
154140
"""Create and return the ReactToMe workflow graph."""
155141
return ReactToMeGraphBuilder(llm, embedding).uncompiled_graph

src/retrievers/csv_chroma.py

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,7 @@ def list_chroma_subdirectories(directory: Path) -> List[str]:
7777
class HybridRetriever:
7878
"""Advanced hybrid retriever supporting RRF, parallel processing, and multi-source search."""
7979

80-
def __init__(
81-
self,
82-
embedding: Embeddings,
83-
embeddings_directory: Path
84-
):
80+
def __init__(self, embedding: Embeddings, embeddings_directory: Path):
8581

8682
self.embedding = embedding
8783
self.embeddings_directory = embeddings_directory
@@ -159,25 +155,17 @@ def _create_vector_retriever(self, subdirectory: str) -> Optional[object]:
159155
return None
160156

161157
async def _search_with_bm25(
162-
self,
163-
query: str,
164-
retriever: BM25Retriever
158+
self, query: str, retriever: BM25Retriever
165159
) -> List[Document]:
166160
"""Search using BM25 retriever asynchronously."""
167161
return await asyncio.to_thread(retriever.get_relevant_documents, query)
168162

169-
async def _search_with_vector(
170-
self,
171-
query: str,
172-
retriever: Any
173-
) -> List[Document]:
163+
async def _search_with_vector(self, query: str, retriever: Any) -> List[Document]:
174164
"""Search using vector retriever asynchronously."""
175165
return await asyncio.to_thread(retriever.get_relevant_documents, query)
176166

177167
async def _execute_hybrid_search(
178-
self,
179-
query: str,
180-
subdirectory: str
168+
self, query: str, subdirectory: str
181169
) -> List[Document]:
182170
"""Execute hybrid search (BM25 + vector) for a single query on a subdirectory."""
183171
retriever_info = self._retrievers.get(subdirectory)
@@ -223,9 +211,7 @@ def _generate_document_identifier(self, document: Document) -> str:
223211
return hashlib.md5(document.page_content.encode()).hexdigest()
224212

225213
async def _apply_reciprocal_rank_fusion(
226-
self,
227-
queries: List[str],
228-
subdirectory: str
214+
self, queries: List[str], subdirectory: str
229215
) -> List[Document]:
230216
"""Apply Reciprocal Rank Fusion to results from multiple queries on a subdirectory."""
231217
logger.info(
@@ -325,8 +311,7 @@ async def ainvoke(self, inputs: Dict[str, Any]) -> str:
325311

326312

327313
def create_hybrid_retriever(
328-
embedding: Embeddings,
329-
embeddings_directory: Path
314+
embedding: Embeddings, embeddings_directory: Path
330315
) -> HybridRetriever:
331316
"""Create a hybrid retriever with RRF and parallel processing support."""
332317
try:

0 commit comments

Comments
 (0)