|
1 | | -import asyncio, warnings, copy, time |
| 1 | +import asyncio |
| 2 | +import copy |
| 3 | +import time |
| 4 | +import warnings |
2 | 5 |
|
3 | 6 | _TERMINAL = object() # sentinel for explicit terminal transitions in Flow.connect |
4 | 7 |
|
5 | 8 |
|
6 | 9 | class BaseNode: |
7 | | - def __init__(self): self.params, self.successors = {}, {} |
8 | | - def set_params(self, params): self.params = params |
9 | | - def prep(self, shared): pass |
10 | | - def exec(self, prep_res): pass |
11 | | - def post(self, shared, prep_res, exec_res): pass |
12 | | - def _exec(self, prep_res): return self.exec(prep_res) |
13 | | - def _run(self, shared): p = self.prep(shared); e = self._exec(p); return self.post(shared, p, e) |
| 10 | + def __init__(self): |
| 11 | + self.params, self.successors = {}, {} |
| 12 | + |
| 13 | + def set_params(self, params): |
| 14 | + self.params = params |
| 15 | + |
| 16 | + def prep(self, shared): |
| 17 | + pass |
| 18 | + |
| 19 | + def exec(self, prep_res): |
| 20 | + pass |
| 21 | + |
| 22 | + def post(self, shared, prep_res, exec_res): |
| 23 | + pass |
| 24 | + |
| 25 | + def _exec(self, prep_res): |
| 26 | + return self.exec(prep_res) |
| 27 | + |
| 28 | + def _run(self, shared): |
| 29 | + p = self.prep(shared) |
| 30 | + e = self._exec(p) |
| 31 | + return self.post(shared, p, e) |
| 32 | + |
14 | 33 | def run(self, shared): |
15 | | - if self.successors: warnings.warn("Node won't run successors. Use Flow.") |
| 34 | + if self.successors: |
| 35 | + warnings.warn("Node won't run successors. Use Flow.") |
16 | 36 | return self._run(shared) |
17 | 37 |
|
18 | 38 |
|
19 | 39 | class Node(BaseNode): |
20 | | - def __init__(self, max_retries=1, wait=0): super().__init__(); self.max_retries, self.wait = max_retries, wait |
21 | | - def exec_fallback(self, prep_res, exc): raise exc |
| 40 | + def __init__(self, max_retries=1, wait=0): |
| 41 | + super().__init__() |
| 42 | + self.max_retries, self.wait = max_retries, wait |
| 43 | + |
| 44 | + def exec_fallback(self, prep_res, exc): |
| 45 | + raise exc |
| 46 | + |
22 | 47 | def _exec(self, prep_res): |
23 | 48 | for self.cur_retry in range(self.max_retries): |
24 | | - try: return self.exec(prep_res) |
| 49 | + try: |
| 50 | + return self.exec(prep_res) |
25 | 51 | except Exception as e: |
26 | | - if self.cur_retry == self.max_retries - 1: return self.exec_fallback(prep_res, e) |
27 | | - if self.wait > 0: time.sleep(self.wait) |
| 52 | + if self.cur_retry == self.max_retries - 1: |
| 53 | + return self.exec_fallback(prep_res, e) |
| 54 | + if self.wait > 0: |
| 55 | + time.sleep(self.wait) |
28 | 56 |
|
29 | 57 |
|
30 | 58 | class BatchNode(Node): |
31 | | - def _exec(self, items): return [super(BatchNode, self)._exec(i) for i in (items or [])] |
| 59 | + def _exec(self, items): |
| 60 | + return [super(BatchNode, self)._exec(i) for i in (items or [])] |
32 | 61 |
|
33 | 62 |
|
34 | 63 | class Flow(BaseNode): |
35 | | - def __init__(self, start=None): super().__init__(); self.start_node = start |
| 64 | + def __init__(self, start=None): |
| 65 | + super().__init__() |
| 66 | + self.start_node = start |
36 | 67 |
|
37 | | - def start(self, start): self.start_node = start; return start |
| 68 | + def start(self, start): |
| 69 | + self.start_node = start |
| 70 | + return start |
38 | 71 |
|
39 | 72 | def connect(self, src, dst, action="default"): |
40 | | - if action in src.successors: warnings.warn(f"Overwriting successor for action '{action}'") |
| 73 | + if action in src.successors: |
| 74 | + warnings.warn(f"Overwriting successor for action '{action}'") |
41 | 75 | src.successors[action] = _TERMINAL if dst is None else dst |
42 | 76 | return self |
43 | 77 |
|
44 | 78 | def get_next_node(self, curr, action): |
45 | 79 | nxt = curr.successors.get(action or "default") |
46 | | - if nxt is _TERMINAL: return None |
47 | | - if nxt is None and curr.successors: warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}") |
| 80 | + if nxt is _TERMINAL: |
| 81 | + return None |
| 82 | + if nxt is None and curr.successors: |
| 83 | + warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}") |
48 | 84 | return nxt |
49 | 85 |
|
50 | 86 | def _orch(self, shared, params=None): |
51 | | - curr, p, last_action = copy.copy(self.start_node), (params or {**self.params}), None |
52 | | - while curr: curr.set_params(p); last_action = curr._run(shared); curr = copy.copy(self.get_next_node(curr, last_action)) |
| 87 | + curr, p, last_action = ( |
| 88 | + copy.copy(self.start_node), |
| 89 | + (params or {**self.params}), |
| 90 | + None, |
| 91 | + ) |
| 92 | + while curr: |
| 93 | + curr.set_params(p) |
| 94 | + last_action = curr._run(shared) |
| 95 | + curr = copy.copy(self.get_next_node(curr, last_action)) |
53 | 96 | return last_action |
54 | 97 |
|
55 | | - def _run(self, shared): p = self.prep(shared); o = self._orch(shared); return self.post(shared, p, o) |
56 | | - def post(self, shared, prep_res, exec_res): return exec_res |
| 98 | + def _run(self, shared): |
| 99 | + p = self.prep(shared) |
| 100 | + o = self._orch(shared) |
| 101 | + return self.post(shared, p, o) |
| 102 | + |
| 103 | + def post(self, shared, prep_res, exec_res): |
| 104 | + return exec_res |
57 | 105 |
|
58 | 106 |
|
59 | 107 | class BatchFlow(Flow): |
60 | 108 | def _run(self, shared): |
61 | 109 | pr = self.prep(shared) or [] |
62 | | - for bp in pr: self._orch(shared, {**self.params, **bp}) |
| 110 | + for bp in pr: |
| 111 | + self._orch(shared, {**self.params, **bp}) |
63 | 112 | return self.post(shared, pr, None) |
64 | 113 |
|
65 | 114 |
|
66 | 115 | class AsyncNode(Node): |
67 | | - async def prep_async(self, shared): pass |
68 | | - async def exec_async(self, prep_res): pass |
69 | | - async def exec_fallback_async(self, prep_res, exc): raise exc |
70 | | - async def post_async(self, shared, prep_res, exec_res): pass |
| 116 | + async def prep_async(self, shared): |
| 117 | + pass |
| 118 | + |
| 119 | + async def exec_async(self, prep_res): |
| 120 | + pass |
| 121 | + |
| 122 | + async def exec_fallback_async(self, prep_res, exc): |
| 123 | + raise exc |
| 124 | + |
| 125 | + async def post_async(self, shared, prep_res, exec_res): |
| 126 | + pass |
| 127 | + |
71 | 128 | async def _exec(self, prep_res): |
72 | 129 | for self.cur_retry in range(self.max_retries): |
73 | | - try: return await self.exec_async(prep_res) |
| 130 | + try: |
| 131 | + return await self.exec_async(prep_res) |
74 | 132 | except Exception as e: |
75 | | - if self.cur_retry == self.max_retries - 1: return await self.exec_fallback_async(prep_res, e) |
76 | | - if self.wait > 0: await asyncio.sleep(self.wait) |
| 133 | + if self.cur_retry == self.max_retries - 1: |
| 134 | + return await self.exec_fallback_async(prep_res, e) |
| 135 | + if self.wait > 0: |
| 136 | + await asyncio.sleep(self.wait) |
| 137 | + |
77 | 138 | async def run_async(self, shared): |
78 | | - if self.successors: warnings.warn("Node won't run successors. Use AsyncFlow.") |
| 139 | + if self.successors: |
| 140 | + warnings.warn("Node won't run successors. Use AsyncFlow.") |
79 | 141 | return await self._run_async(shared) |
80 | | - async def _run_async(self, shared): p = await self.prep_async(shared); e = await self._exec(p); return await self.post_async(shared, p, e) |
81 | | - def _run(self, shared): raise RuntimeError("Use run_async.") |
| 142 | + |
| 143 | + async def _run_async(self, shared): |
| 144 | + p = await self.prep_async(shared) |
| 145 | + e = await self._exec(p) |
| 146 | + return await self.post_async(shared, p, e) |
| 147 | + |
| 148 | + def _run(self, shared): |
| 149 | + raise RuntimeError("Use run_async.") |
82 | 150 |
|
83 | 151 |
|
84 | 152 | class AsyncBatchNode(AsyncNode, BatchNode): |
85 | | - async def _exec(self, items): return [await super(AsyncBatchNode, self)._exec(i) for i in items] |
| 153 | + async def _exec(self, items): |
| 154 | + return [await super(AsyncBatchNode, self)._exec(i) for i in items] |
86 | 155 |
|
87 | 156 |
|
88 | 157 | class AsyncParallelBatchNode(AsyncNode, BatchNode): |
89 | | - async def _exec(self, items): return await asyncio.gather(*(super(AsyncParallelBatchNode, self)._exec(i) for i in items)) |
| 158 | + async def _exec(self, items): |
| 159 | + return await asyncio.gather( |
| 160 | + *(super(AsyncParallelBatchNode, self)._exec(i) for i in items) |
| 161 | + ) |
90 | 162 |
|
91 | 163 |
|
92 | 164 | class AsyncFlow(Flow, AsyncNode): |
93 | 165 | async def _orch_async(self, shared, params=None): |
94 | | - curr, p, last_action = copy.copy(self.start_node), (params or {**self.params}), None |
95 | | - while curr: curr.set_params(p); last_action = await curr._run_async(shared) if isinstance(curr, AsyncNode) else curr._run(shared); curr = copy.copy(self.get_next_node(curr, last_action)) |
| 166 | + curr, p, last_action = ( |
| 167 | + copy.copy(self.start_node), |
| 168 | + (params or {**self.params}), |
| 169 | + None, |
| 170 | + ) |
| 171 | + while curr: |
| 172 | + curr.set_params(p) |
| 173 | + last_action = ( |
| 174 | + await curr._run_async(shared) |
| 175 | + if isinstance(curr, AsyncNode) |
| 176 | + else curr._run(shared) |
| 177 | + ) |
| 178 | + curr = copy.copy(self.get_next_node(curr, last_action)) |
96 | 179 | return last_action |
97 | | - async def _run_async(self, shared): p = await self.prep_async(shared); o = await self._orch_async(shared); return await self.post_async(shared, p, o) |
98 | | - async def post_async(self, shared, prep_res, exec_res): return exec_res |
| 180 | + |
| 181 | + async def _run_async(self, shared): |
| 182 | + p = await self.prep_async(shared) |
| 183 | + o = await self._orch_async(shared) |
| 184 | + return await self.post_async(shared, p, o) |
| 185 | + |
| 186 | + async def post_async(self, shared, prep_res, exec_res): |
| 187 | + return exec_res |
99 | 188 |
|
100 | 189 |
|
101 | 190 | class AsyncBatchFlow(AsyncFlow, BatchFlow): |
102 | 191 | async def _run_async(self, shared): |
103 | 192 | pr = await self.prep_async(shared) or [] |
104 | | - for bp in pr: await self._orch_async(shared, {**self.params, **bp}) |
| 193 | + for bp in pr: |
| 194 | + await self._orch_async(shared, {**self.params, **bp}) |
105 | 195 | return await self.post_async(shared, pr, None) |
106 | 196 |
|
107 | 197 |
|
108 | 198 | class AsyncParallelBatchFlow(AsyncFlow, BatchFlow): |
109 | 199 | async def _run_async(self, shared): |
110 | 200 | pr = await self.prep_async(shared) or [] |
111 | | - await asyncio.gather(*(self._orch_async(shared, {**self.params, **bp}) for bp in pr)) |
| 201 | + await asyncio.gather( |
| 202 | + *(self._orch_async(shared, {**self.params, **bp}) for bp in pr) |
| 203 | + ) |
112 | 204 | return await self.post_async(shared, pr, None) |
113 | 205 |
|
114 | 206 |
|
115 | | -from .llm import LLMClient |
116 | | -from .rag import RAGNode |
| 207 | +from .llm import LLMClient # noqa: E402 - avoid circular import with .llm |
| 208 | +from .rag import RAGNode # noqa: E402 - avoid circular import with .rag |
117 | 209 |
|
118 | 210 | __all__ = [ |
119 | | - "BaseNode", "Node", "BatchNode", |
120 | | - "Flow", "BatchFlow", |
121 | | - "AsyncNode", "AsyncBatchNode", "AsyncParallelBatchNode", |
122 | | - "AsyncFlow", "AsyncBatchFlow", "AsyncParallelBatchFlow", |
123 | | - "LLMClient", "RAGNode", |
| 211 | + "BaseNode", |
| 212 | + "Node", |
| 213 | + "BatchNode", |
| 214 | + "Flow", |
| 215 | + "BatchFlow", |
| 216 | + "AsyncNode", |
| 217 | + "AsyncBatchNode", |
| 218 | + "AsyncParallelBatchNode", |
| 219 | + "AsyncFlow", |
| 220 | + "AsyncBatchFlow", |
| 221 | + "AsyncParallelBatchFlow", |
| 222 | + "LLMClient", |
| 223 | + "RAGNode", |
124 | 224 | ] |
0 commit comments