Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions pocketflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@ def __init__(self,src,action): self.src,self.action=src,action
def __rshift__(self,tgt): return self.src.next(tgt,self.action)

class Node(BaseNode):
def __init__(self,max_retries=1,wait=0): super().__init__(); self.max_retries,self.wait=max_retries,wait
def __init__(self,max_retries=1,wait=0):
if max_retries < 1:
raise ValueError("max_retries must be at least 1")
if wait < 0:
raise ValueError("wait must be non-negative")
super().__init__(); self.max_retries,self.wait=max_retries,wait
def exec_fallback(self,prep_res,exc): raise exc
def _exec(self,prep_res):
for self.cur_retry in range(self.max_retries):
Expand All @@ -40,9 +45,11 @@ class Flow(BaseNode):
def __init__(self,start=None): super().__init__(); self.start_node=start
def start(self,start): self.start_node=start; return start
def get_next_node(self,curr,action):
nxt=curr.successors.get(action or "default")
if not nxt and curr.successors: warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}")
return nxt
key=action or "default"
if key not in curr.successors:
if curr.successors: warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}")
return None
return curr.successors[key]
def _orch(self,shared,params=None):
curr,p,last_action =copy.copy(self.start_node),(params or {**self.params}),None
while curr: curr.set_params(p); last_action=curr._run(shared); curr=copy.copy(self.get_next_node(curr,last_action))
Expand Down Expand Up @@ -74,10 +81,10 @@ async def _run_async(self,shared): p=await self.prep_async(shared); e=await self
def _run(self,shared): raise RuntimeError("Use run_async.")

class AsyncBatchNode(AsyncNode,BatchNode):
async def _exec(self,items): return [await super(AsyncBatchNode,self)._exec(i) for i in items]
async def _exec(self,items): return [await super(AsyncBatchNode,self)._exec(i) for i in (items or [])]

class AsyncParallelBatchNode(AsyncNode,BatchNode):
async def _exec(self,items): return await asyncio.gather(*(super(AsyncParallelBatchNode,self)._exec(i) for i in items))
async def _exec(self,items): return await asyncio.gather(*(super(AsyncParallelBatchNode,self)._exec(i) for i in (items or [])))

class AsyncFlow(Flow,AsyncNode):
async def _orch_async(self,shared,params=None):
Expand Down
150 changes: 150 additions & 0 deletions tests/test_async_batch_none_handling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
"""Tests for AsyncBatchNode/AsyncParallelBatchNode handling of None prep results."""

import unittest
import asyncio
from pocketflow import AsyncBatchNode, AsyncParallelBatchNode, BatchNode


class TestBatchNodeNoneHandling(unittest.TestCase):
"""Verify synchronous BatchNode handles None gracefully (baseline)."""

def test_batch_node_none_returns_empty(self):
"""BatchNode with None prep result should return []."""

class MyBatch(BatchNode):
def exec(self, prep_res):
return f"processed:{prep_res}"

def post(self, shared, prep_res, exec_res):
return exec_res

node = MyBatch()
# prep returns None by default, _exec gets None
result = node._run({})
self.assertEqual(result, [])

def test_batch_node_empty_list_returns_empty(self):
"""BatchNode with empty list should return []."""

class MyBatch(BatchNode):
def prep(self, shared):
return shared.get("items", [])

def exec(self, prep_res):
return f"processed:{prep_res}"

def post(self, shared, prep_res, exec_res):
return exec_res

node = MyBatch()
result = node.run({"items": []})
self.assertEqual(result, [])


class TestAsyncBatchNodeNoneHandling(unittest.TestCase):
"""AsyncBatchNode must handle None prep results like BatchNode does."""

def test_async_batch_node_none_returns_empty(self):
"""AsyncBatchNode with None prep result should return [], not crash."""

class MyAsyncBatch(AsyncBatchNode):
async def exec_async(self, prep_res):
return f"processed:{prep_res}"

async def post_async(self, shared, prep_res, exec_res):
return exec_res

node = MyAsyncBatch()
result = asyncio.run(node._run_async({}))
self.assertEqual(result, [])

def test_async_batch_node_empty_list_returns_empty(self):
"""AsyncBatchNode with empty list should return []."""

class MyAsyncBatch(AsyncBatchNode):
async def prep_async(self, shared):
return shared.get("items", [])

async def exec_async(self, prep_res):
return f"processed:{prep_res}"

async def post_async(self, shared, prep_res, exec_res):
return exec_res

node = MyAsyncBatch()
result = asyncio.run(node.run_async({"items": []}))
self.assertEqual(result, [])

def test_async_batch_node_with_data(self):
"""AsyncBatchNode should still work correctly with actual data."""

class MyAsyncBatch(AsyncBatchNode):
async def prep_async(self, shared):
return shared.get("items", [])

async def exec_async(self, prep_res):
return f"processed:{prep_res}"

async def post_async(self, shared, prep_res, exec_res):
return exec_res

node = MyAsyncBatch()
result = asyncio.run(node.run_async({"items": ["a", "b", "c"]}))
self.assertEqual(result, ["processed:a", "processed:b", "processed:c"])


class TestAsyncParallelBatchNodeNoneHandling(unittest.TestCase):
"""AsyncParallelBatchNode must handle None prep results like BatchNode does."""

def test_async_parallel_batch_node_none_returns_empty(self):
"""AsyncParallelBatchNode with None prep result should return [], not crash."""

class MyAsyncParallelBatch(AsyncParallelBatchNode):
async def exec_async(self, prep_res):
return f"processed:{prep_res}"

async def post_async(self, shared, prep_res, exec_res):
return exec_res

node = MyAsyncParallelBatch()
result = asyncio.run(node._run_async({}))
self.assertEqual(result, [])

def test_async_parallel_batch_node_empty_list_returns_empty(self):
"""AsyncParallelBatchNode with empty list should return []."""

class MyAsyncParallelBatch(AsyncParallelBatchNode):
async def prep_async(self, shared):
return shared.get("items", [])

async def exec_async(self, prep_res):
return f"processed:{prep_res}"

async def post_async(self, shared, prep_res, exec_res):
return exec_res

node = MyAsyncParallelBatch()
result = asyncio.run(node.run_async({"items": []}))
self.assertEqual(result, [])

def test_async_parallel_batch_node_with_data(self):
"""AsyncParallelBatchNode should still work correctly with actual data."""

class MyAsyncParallelBatch(AsyncParallelBatchNode):
async def prep_async(self, shared):
return shared.get("items", [])

async def exec_async(self, prep_res):
return f"processed:{prep_res}"

async def post_async(self, shared, prep_res, exec_res):
return exec_res

node = MyAsyncParallelBatch()
result = asyncio.run(node.run_async({"items": ["a", "b", "c"]}))
# Parallel execution may reorder, so compare as sets
self.assertEqual(set(result), {"processed:a", "processed:b", "processed:c"})


if __name__ == "__main__":
unittest.main()
124 changes: 124 additions & 0 deletions tests/test_flow_none_successor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""Tests for Flow.get_next_node handling of None successors (issue #56)."""

import unittest
import warnings
from pocketflow import Flow, Node


class TestNoneSuccessor(unittest.TestCase):
""">> None should create a valid terminal successor without triggering warnings."""

def test_none_successor_no_warning(self):
"""node - 'end' >> None should not trigger a warning when 'end' is returned."""

class Decide(Node):
def exec(self, prep_res):
return "end"

def post(self, shared, prep_res, exec_res):
return exec_res

decide = Decide()
decide - "end" >> None

flow = Flow(start=decide)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
result = flow.run({})
flow_warnings = [x for x in w if "Flow ends" in str(x.message)]
self.assertEqual(len(flow_warnings), 0,
f"Unexpected warning: {[str(x.message) for x in flow_warnings]}")
self.assertEqual(result, "end")

def test_none_successor_with_other_branches(self):
"""Mixed successors: >> node and >> None should coexist."""

class Branch(Node):
def exec(self, prep_res):
return self.params.get("action", "end")

def post(self, shared, prep_res, exec_res):
return exec_res

class Process(Node):
def exec(self, prep_res):
return "processed"

def post(self, shared, prep_res, exec_res):
return exec_res

branch = Branch()
process = Process()
branch - "continue" >> process
branch - "end" >> None

# Test terminating branch
flow1 = Flow(start=branch)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
result1 = flow1.run({"action": "end"})
flow_warnings = [x for x in w if "Flow ends" in str(x.message)]
self.assertEqual(len(flow_warnings), 0)
self.assertEqual(result1, "end")

# Test continuing branch
flow2 = Flow(start=branch)
flow2.params = {"action": "continue"}
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
result2 = flow2.run({})
flow_warnings = [x for x in w if "Flow ends" in str(x.message)]
# 'processed' is returned by process, which has no successors —
# that's a legitimate end, no warning expected
self.assertEqual(len(flow_warnings), 0)
self.assertEqual(result2, "processed")

def test_unknown_action_still_warns(self):
"""Returning an action with no registered successor should still warn."""

class Decide(Node):
def exec(self, prep_res):
return "nonexistent"

def post(self, shared, prep_res, exec_res):
return exec_res

class Next(Node):
pass

decide = Decide()
next_node = Next()
decide - "expected" >> next_node

flow = Flow(start=decide)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
flow.run({})
flow_warnings = [x for x in w if "Flow ends" in str(x.message)]
self.assertEqual(len(flow_warnings), 1)
self.assertIn("nonexistent", str(flow_warnings[0].message))

def test_default_none_successor(self):
""">> None with default action should not warn."""

class End(Node):
def exec(self, prep_res):
return "default"

def post(self, shared, prep_res, exec_res):
return exec_res

end = End()
end >> None # default successor is None

flow = Flow(start=end)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
result = flow.run({})
flow_warnings = [x for x in w if "Flow ends" in str(x.message)]
self.assertEqual(len(flow_warnings), 0)
self.assertEqual(result, "default")


if __name__ == "__main__":
unittest.main()
Loading