Skip to content

Commit 3052b45

Browse files
authored
Merge pull request #110 from hendo-21/fix/ruff-failures
fix ruff failures
2 parents cf4c3e3 + 5eaa2e5 commit 3052b45

33 files changed

Lines changed: 735 additions & 474 deletions

src/fenn/agents/__init__.py

Lines changed: 150 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,124 +1,224 @@
1-
import asyncio, warnings, copy, time
1+
import asyncio
2+
import copy
3+
import time
4+
import warnings
25

36
_TERMINAL = object() # sentinel for explicit terminal transitions in Flow.connect
47

58

69
class BaseNode:
7-
def __init__(self): self.params, self.successors = {}, {}
8-
def set_params(self, params): self.params = params
9-
def prep(self, shared): pass
10-
def exec(self, prep_res): pass
11-
def post(self, shared, prep_res, exec_res): pass
12-
def _exec(self, prep_res): return self.exec(prep_res)
13-
def _run(self, shared): p = self.prep(shared); e = self._exec(p); return self.post(shared, p, e)
10+
def __init__(self):
11+
self.params, self.successors = {}, {}
12+
13+
def set_params(self, params):
14+
self.params = params
15+
16+
def prep(self, shared):
17+
pass
18+
19+
def exec(self, prep_res):
20+
pass
21+
22+
def post(self, shared, prep_res, exec_res):
23+
pass
24+
25+
def _exec(self, prep_res):
26+
return self.exec(prep_res)
27+
28+
def _run(self, shared):
29+
p = self.prep(shared)
30+
e = self._exec(p)
31+
return self.post(shared, p, e)
32+
1433
def run(self, shared):
15-
if self.successors: warnings.warn("Node won't run successors. Use Flow.")
34+
if self.successors:
35+
warnings.warn("Node won't run successors. Use Flow.")
1636
return self._run(shared)
1737

1838

1939
class Node(BaseNode):
20-
def __init__(self, max_retries=1, wait=0): super().__init__(); self.max_retries, self.wait = max_retries, wait
21-
def exec_fallback(self, prep_res, exc): raise exc
40+
def __init__(self, max_retries=1, wait=0):
41+
super().__init__()
42+
self.max_retries, self.wait = max_retries, wait
43+
44+
def exec_fallback(self, prep_res, exc):
45+
raise exc
46+
2247
def _exec(self, prep_res):
2348
for self.cur_retry in range(self.max_retries):
24-
try: return self.exec(prep_res)
49+
try:
50+
return self.exec(prep_res)
2551
except Exception as e:
26-
if self.cur_retry == self.max_retries - 1: return self.exec_fallback(prep_res, e)
27-
if self.wait > 0: time.sleep(self.wait)
52+
if self.cur_retry == self.max_retries - 1:
53+
return self.exec_fallback(prep_res, e)
54+
if self.wait > 0:
55+
time.sleep(self.wait)
2856

2957

3058
class BatchNode(Node):
31-
def _exec(self, items): return [super(BatchNode, self)._exec(i) for i in (items or [])]
59+
def _exec(self, items):
60+
return [super(BatchNode, self)._exec(i) for i in (items or [])]
3261

3362

3463
class Flow(BaseNode):
35-
def __init__(self, start=None): super().__init__(); self.start_node = start
64+
def __init__(self, start=None):
65+
super().__init__()
66+
self.start_node = start
3667

37-
def start(self, start): self.start_node = start; return start
68+
def start(self, start):
69+
self.start_node = start
70+
return start
3871

3972
def connect(self, src, dst, action="default"):
40-
if action in src.successors: warnings.warn(f"Overwriting successor for action '{action}'")
73+
if action in src.successors:
74+
warnings.warn(f"Overwriting successor for action '{action}'")
4175
src.successors[action] = _TERMINAL if dst is None else dst
4276
return self
4377

4478
def get_next_node(self, curr, action):
4579
nxt = curr.successors.get(action or "default")
46-
if nxt is _TERMINAL: return None
47-
if nxt is None and curr.successors: warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}")
80+
if nxt is _TERMINAL:
81+
return None
82+
if nxt is None and curr.successors:
83+
warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}")
4884
return nxt
4985

5086
def _orch(self, shared, params=None):
51-
curr, p, last_action = copy.copy(self.start_node), (params or {**self.params}), None
52-
while curr: curr.set_params(p); last_action = curr._run(shared); curr = copy.copy(self.get_next_node(curr, last_action))
87+
curr, p, last_action = (
88+
copy.copy(self.start_node),
89+
(params or {**self.params}),
90+
None,
91+
)
92+
while curr:
93+
curr.set_params(p)
94+
last_action = curr._run(shared)
95+
curr = copy.copy(self.get_next_node(curr, last_action))
5396
return last_action
5497

55-
def _run(self, shared): p = self.prep(shared); o = self._orch(shared); return self.post(shared, p, o)
56-
def post(self, shared, prep_res, exec_res): return exec_res
98+
def _run(self, shared):
99+
p = self.prep(shared)
100+
o = self._orch(shared)
101+
return self.post(shared, p, o)
102+
103+
def post(self, shared, prep_res, exec_res):
104+
return exec_res
57105

58106

59107
class BatchFlow(Flow):
60108
def _run(self, shared):
61109
pr = self.prep(shared) or []
62-
for bp in pr: self._orch(shared, {**self.params, **bp})
110+
for bp in pr:
111+
self._orch(shared, {**self.params, **bp})
63112
return self.post(shared, pr, None)
64113

65114

66115
class AsyncNode(Node):
67-
async def prep_async(self, shared): pass
68-
async def exec_async(self, prep_res): pass
69-
async def exec_fallback_async(self, prep_res, exc): raise exc
70-
async def post_async(self, shared, prep_res, exec_res): pass
116+
async def prep_async(self, shared):
117+
pass
118+
119+
async def exec_async(self, prep_res):
120+
pass
121+
122+
async def exec_fallback_async(self, prep_res, exc):
123+
raise exc
124+
125+
async def post_async(self, shared, prep_res, exec_res):
126+
pass
127+
71128
async def _exec(self, prep_res):
72129
for self.cur_retry in range(self.max_retries):
73-
try: return await self.exec_async(prep_res)
130+
try:
131+
return await self.exec_async(prep_res)
74132
except Exception as e:
75-
if self.cur_retry == self.max_retries - 1: return await self.exec_fallback_async(prep_res, e)
76-
if self.wait > 0: await asyncio.sleep(self.wait)
133+
if self.cur_retry == self.max_retries - 1:
134+
return await self.exec_fallback_async(prep_res, e)
135+
if self.wait > 0:
136+
await asyncio.sleep(self.wait)
137+
77138
async def run_async(self, shared):
78-
if self.successors: warnings.warn("Node won't run successors. Use AsyncFlow.")
139+
if self.successors:
140+
warnings.warn("Node won't run successors. Use AsyncFlow.")
79141
return await self._run_async(shared)
80-
async def _run_async(self, shared): p = await self.prep_async(shared); e = await self._exec(p); return await self.post_async(shared, p, e)
81-
def _run(self, shared): raise RuntimeError("Use run_async.")
142+
143+
async def _run_async(self, shared):
144+
p = await self.prep_async(shared)
145+
e = await self._exec(p)
146+
return await self.post_async(shared, p, e)
147+
148+
def _run(self, shared):
149+
raise RuntimeError("Use run_async.")
82150

83151

84152
class AsyncBatchNode(AsyncNode, BatchNode):
85-
async def _exec(self, items): return [await super(AsyncBatchNode, self)._exec(i) for i in items]
153+
async def _exec(self, items):
154+
return [await super(AsyncBatchNode, self)._exec(i) for i in items]
86155

87156

88157
class AsyncParallelBatchNode(AsyncNode, BatchNode):
89-
async def _exec(self, items): return await asyncio.gather(*(super(AsyncParallelBatchNode, self)._exec(i) for i in items))
158+
async def _exec(self, items):
159+
return await asyncio.gather(
160+
*(super(AsyncParallelBatchNode, self)._exec(i) for i in items)
161+
)
90162

91163

92164
class AsyncFlow(Flow, AsyncNode):
93165
async def _orch_async(self, shared, params=None):
94-
curr, p, last_action = copy.copy(self.start_node), (params or {**self.params}), None
95-
while curr: curr.set_params(p); last_action = await curr._run_async(shared) if isinstance(curr, AsyncNode) else curr._run(shared); curr = copy.copy(self.get_next_node(curr, last_action))
166+
curr, p, last_action = (
167+
copy.copy(self.start_node),
168+
(params or {**self.params}),
169+
None,
170+
)
171+
while curr:
172+
curr.set_params(p)
173+
last_action = (
174+
await curr._run_async(shared)
175+
if isinstance(curr, AsyncNode)
176+
else curr._run(shared)
177+
)
178+
curr = copy.copy(self.get_next_node(curr, last_action))
96179
return last_action
97-
async def _run_async(self, shared): p = await self.prep_async(shared); o = await self._orch_async(shared); return await self.post_async(shared, p, o)
98-
async def post_async(self, shared, prep_res, exec_res): return exec_res
180+
181+
async def _run_async(self, shared):
182+
p = await self.prep_async(shared)
183+
o = await self._orch_async(shared)
184+
return await self.post_async(shared, p, o)
185+
186+
async def post_async(self, shared, prep_res, exec_res):
187+
return exec_res
99188

100189

101190
class AsyncBatchFlow(AsyncFlow, BatchFlow):
102191
async def _run_async(self, shared):
103192
pr = await self.prep_async(shared) or []
104-
for bp in pr: await self._orch_async(shared, {**self.params, **bp})
193+
for bp in pr:
194+
await self._orch_async(shared, {**self.params, **bp})
105195
return await self.post_async(shared, pr, None)
106196

107197

108198
class AsyncParallelBatchFlow(AsyncFlow, BatchFlow):
109199
async def _run_async(self, shared):
110200
pr = await self.prep_async(shared) or []
111-
await asyncio.gather(*(self._orch_async(shared, {**self.params, **bp}) for bp in pr))
201+
await asyncio.gather(
202+
*(self._orch_async(shared, {**self.params, **bp}) for bp in pr)
203+
)
112204
return await self.post_async(shared, pr, None)
113205

114206

115-
from .llm import LLMClient
116-
from .rag import RAGNode
207+
from .llm import LLMClient # noqa: E402 - avoid circular import with .llm
208+
from .rag import RAGNode # noqa: E402 - avoid circular import with .rag
117209

118210
__all__ = [
119-
"BaseNode", "Node", "BatchNode",
120-
"Flow", "BatchFlow",
121-
"AsyncNode", "AsyncBatchNode", "AsyncParallelBatchNode",
122-
"AsyncFlow", "AsyncBatchFlow", "AsyncParallelBatchFlow",
123-
"LLMClient", "RAGNode",
211+
"BaseNode",
212+
"Node",
213+
"BatchNode",
214+
"Flow",
215+
"BatchFlow",
216+
"AsyncNode",
217+
"AsyncBatchNode",
218+
"AsyncParallelBatchNode",
219+
"AsyncFlow",
220+
"AsyncBatchFlow",
221+
"AsyncParallelBatchFlow",
222+
"LLMClient",
223+
"RAGNode",
124224
]

src/fenn/agents/__init__.pyi

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
import asyncio
2-
from typing import Any, Dict, Iterator, List, Optional, Union, TypeVar, Generic
1+
from typing import Any, Dict, Generic, Iterator, List, Optional, TypeVar, Union
32

4-
_PrepResult = TypeVar('_PrepResult')
5-
_ExecResult = TypeVar('_ExecResult')
6-
_PostResult = TypeVar('_PostResult')
3+
_PrepResult = TypeVar("_PrepResult")
4+
_ExecResult = TypeVar("_ExecResult")
5+
_PostResult = TypeVar("_PostResult")
76

87
ParamValue = Union[str, int, float, bool, None, List[Any], Dict[str, Any]]
98
SharedData = Dict[str, Any]
@@ -18,7 +17,9 @@ class BaseNode(Generic[_PrepResult, _ExecResult, _PostResult]):
1817
def set_params(self, params: Params) -> None: ...
1918
def prep(self, shared: SharedData) -> _PrepResult: ...
2019
def exec(self, prep_res: _PrepResult) -> _ExecResult: ...
21-
def post(self, shared: SharedData, prep_res: _PrepResult, exec_res: _ExecResult) -> _PostResult: ...
20+
def post(
21+
self, shared: SharedData, prep_res: _PrepResult, exec_res: _ExecResult
22+
) -> _PostResult: ...
2223
def _exec(self, prep_res: _PrepResult) -> _ExecResult: ...
2324
def _run(self, shared: SharedData) -> _PostResult: ...
2425
def run(self, shared: SharedData) -> _PostResult: ...
@@ -51,36 +52,60 @@ class Flow(BaseNode[_PrepResult, Any, _PostResult]):
5152
) -> Optional[BaseNode[Any, Any, Any]]: ...
5253
def _orch(self, shared: SharedData, params: Optional[Params] = None) -> Any: ...
5354
def _run(self, shared: SharedData) -> _PostResult: ...
54-
def post(self, shared: SharedData, prep_res: _PrepResult, exec_res: Any) -> _PostResult: ...
55+
def post(
56+
self, shared: SharedData, prep_res: _PrepResult, exec_res: Any
57+
) -> _PostResult: ...
5558

5659
class BatchFlow(Flow[Optional[List[Params]], Any, _PostResult]):
5760
def _run(self, shared: SharedData) -> _PostResult: ...
5861

5962
class AsyncNode(Node[_PrepResult, _ExecResult, _PostResult]):
6063
async def prep_async(self, shared: SharedData) -> _PrepResult: ...
6164
async def exec_async(self, prep_res: _PrepResult) -> _ExecResult: ...
62-
async def exec_fallback_async(self, prep_res: _PrepResult, exc: Exception) -> _ExecResult: ...
63-
async def post_async(self, shared: SharedData, prep_res: _PrepResult, exec_res: _ExecResult) -> _PostResult: ...
65+
async def exec_fallback_async(
66+
self, prep_res: _PrepResult, exc: Exception
67+
) -> _ExecResult: ...
68+
async def post_async(
69+
self, shared: SharedData, prep_res: _PrepResult, exec_res: _ExecResult
70+
) -> _PostResult: ...
6471
async def _exec(self, prep_res: _PrepResult) -> _ExecResult: ...
6572
async def run_async(self, shared: SharedData) -> _PostResult: ...
6673
async def _run_async(self, shared: SharedData) -> _PostResult: ...
6774
def _run(self, shared: SharedData) -> _PostResult: ...
6875

69-
class AsyncBatchNode(AsyncNode[Optional[List[_PrepResult]], List[_ExecResult], _PostResult], BatchNode[Optional[List[_PrepResult]], List[_ExecResult], _PostResult]):
76+
class AsyncBatchNode(
77+
AsyncNode[Optional[List[_PrepResult]], List[_ExecResult], _PostResult],
78+
BatchNode[Optional[List[_PrepResult]], List[_ExecResult], _PostResult],
79+
):
7080
async def _exec(self, items: Optional[List[_PrepResult]]) -> List[_ExecResult]: ...
7181

72-
class AsyncParallelBatchNode(AsyncNode[Optional[List[_PrepResult]], List[_ExecResult], _PostResult], BatchNode[Optional[List[_PrepResult]], List[_ExecResult], _PostResult]):
82+
class AsyncParallelBatchNode(
83+
AsyncNode[Optional[List[_PrepResult]], List[_ExecResult], _PostResult],
84+
BatchNode[Optional[List[_PrepResult]], List[_ExecResult], _PostResult],
85+
):
7386
async def _exec(self, items: Optional[List[_PrepResult]]) -> List[_ExecResult]: ...
7487

75-
class AsyncFlow(Flow[_PrepResult, Any, _PostResult], AsyncNode[_PrepResult, Any, _PostResult]):
76-
async def _orch_async(self, shared: SharedData, params: Optional[Params] = None) -> Any: ...
88+
class AsyncFlow(
89+
Flow[_PrepResult, Any, _PostResult], AsyncNode[_PrepResult, Any, _PostResult]
90+
):
91+
async def _orch_async(
92+
self, shared: SharedData, params: Optional[Params] = None
93+
) -> Any: ...
7794
async def _run_async(self, shared: SharedData) -> _PostResult: ...
78-
async def post_async(self, shared: SharedData, prep_res: _PrepResult, exec_res: Any) -> _PostResult: ...
79-
80-
class AsyncBatchFlow(AsyncFlow[Optional[List[Params]], Any, _PostResult], BatchFlow[Optional[List[Params]], Any, _PostResult]):
95+
async def post_async(
96+
self, shared: SharedData, prep_res: _PrepResult, exec_res: Any
97+
) -> _PostResult: ...
98+
99+
class AsyncBatchFlow(
100+
AsyncFlow[Optional[List[Params]], Any, _PostResult],
101+
BatchFlow[Optional[List[Params]], Any, _PostResult],
102+
):
81103
async def _run_async(self, shared: SharedData) -> _PostResult: ...
82104

83-
class AsyncParallelBatchFlow(AsyncFlow[Optional[List[Params]], Any, _PostResult], BatchFlow[Optional[List[Params]], Any, _PostResult]):
105+
class AsyncParallelBatchFlow(
106+
AsyncFlow[Optional[List[Params]], Any, _PostResult],
107+
BatchFlow[Optional[List[Params]], Any, _PostResult],
108+
):
84109
async def _run_async(self, shared: SharedData) -> _PostResult: ...
85110

86111
class LLMClient:

0 commit comments

Comments
 (0)