Skip to content

Commit 5b1fd1d

Browse files
Merge pull request #518 from MervinPraison/claude/issue-484-20250527_043953
feat: Add context management for workflow loops
2 parents bb91873 + 73e38fa commit 5b1fd1d

3 files changed

Lines changed: 163 additions & 39 deletions

File tree

src/praisonai-agents/praisonaiagents/process/process.py

Lines changed: 42 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,44 @@ def __init__(self, tasks: Dict[str, Task], agents: List[Agent], manager_llm: Opt
3030
self.task_retry_counter: Dict[str, int] = {} # Initialize retry counter
3131
self.workflow_finished = False # ADDED: Workflow finished flag
3232

33+
def _build_task_context(self, current_task: Task) -> str:
34+
"""Build context for a task based on its retain_full_context setting"""
35+
if not (current_task.previous_tasks or current_task.context):
36+
return ""
37+
38+
context = "\nInput data from previous tasks:"
39+
40+
if current_task.retain_full_context:
41+
# Original behavior: include all previous tasks
42+
for prev_name in current_task.previous_tasks:
43+
prev_task = next((t for t in self.tasks.values() if t.name == prev_name), None)
44+
if prev_task and prev_task.result:
45+
context += f"\n{prev_name}: {prev_task.result.raw}"
46+
47+
# Add data from context tasks
48+
if current_task.context:
49+
for ctx_task in current_task.context:
50+
if ctx_task.result and ctx_task.name != current_task.name:
51+
context += f"\n{ctx_task.name}: {ctx_task.result.raw}"
52+
else:
53+
# New behavior: only include the most recent previous task
54+
if current_task.previous_tasks:
55+
# Get the most recent previous task (last in the list)
56+
prev_name = current_task.previous_tasks[-1]
57+
prev_task = next((t for t in self.tasks.values() if t.name == prev_name), None)
58+
if prev_task and prev_task.result:
59+
context += f"\n{prev_name}: {prev_task.result.raw}"
60+
61+
# For context tasks, still include the most recent one
62+
if current_task.context:
63+
# Get the most recent context task with a result
64+
for ctx_task in reversed(current_task.context):
65+
if ctx_task.result and ctx_task.name != current_task.name:
66+
context += f"\n{ctx_task.name}: {ctx_task.result.raw}"
67+
break # Only include the most recent one
68+
69+
return context
70+
3371
def _find_next_not_started_task(self) -> Optional[Task]:
3472
"""Fallback mechanism to find the next 'not started' task."""
3573
fallback_attempts = 0
@@ -147,25 +185,8 @@ async def aworkflow(self) -> AsyncGenerator[str, None]:
147185
""")
148186

149187
# Add context from previous tasks to description
150-
if current_task.previous_tasks or current_task.context:
151-
context = "\nInput data from previous tasks:"
152-
153-
# Add data from previous tasks in workflow
154-
for prev_name in current_task.previous_tasks:
155-
prev_task = next((t for t in self.tasks.values() if t.name == prev_name), None)
156-
if prev_task and prev_task.result:
157-
# Handle loop data
158-
if current_task.task_type == "loop":
159-
context += f"\n{prev_name}: {prev_task.result.raw}"
160-
else:
161-
context += f"\n{prev_name}: {prev_task.result.raw}"
162-
163-
# Add data from context tasks
164-
if current_task.context:
165-
for ctx_task in current_task.context:
166-
if ctx_task.result and ctx_task.name != current_task.name:
167-
context += f"\n{ctx_task.name}: {ctx_task.result.raw}"
168-
188+
context = self._build_task_context(current_task)
189+
if context:
169190
# Update task description with context
170191
current_task.description = current_task.description + context
171192

@@ -778,25 +799,8 @@ def workflow(self):
778799
""")
779800

780801
# Add context from previous tasks to description
781-
if current_task.previous_tasks or current_task.context:
782-
context = "\nInput data from previous tasks:"
783-
784-
# Add data from previous tasks in workflow
785-
for prev_name in current_task.previous_tasks:
786-
prev_task = next((t for t in self.tasks.values() if t.name == prev_name), None)
787-
if prev_task and prev_task.result:
788-
# Handle loop data
789-
if current_task.task_type == "loop":
790-
context += f"\n{prev_name}: {prev_task.result.raw}"
791-
else:
792-
context += f"\n{prev_name}: {prev_task.result.raw}"
793-
794-
# Add data from context tasks
795-
if current_task.context:
796-
for ctx_task in current_task.context:
797-
if ctx_task.result and ctx_task.name != current_task.name:
798-
context += f"\n{ctx_task.name}: {ctx_task.result.raw}"
799-
802+
context = self._build_task_context(current_task)
803+
if context:
800804
# Update task description with context
801805
current_task.description = current_task.description + context
802806

src/praisonai-agents/praisonaiagents/task/task.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ def __init__(
3939
memory=None,
4040
quality_check=True,
4141
input_file: Optional[str] = None,
42-
rerun: bool = False # Renamed from can_rerun and logic inverted, default True for backward compatibility
42+
rerun: bool = False, # Renamed from can_rerun and logic inverted, default True for backward compatibility
43+
retain_full_context: bool = False # By default, only use previous task output, not all previous tasks
4344
):
4445
# Add check if memory config is provided
4546
if memory is not None or (config and config.get('memory_config')):
@@ -78,6 +79,7 @@ def __init__(
7879
self.memory = memory
7980
self.quality_check = quality_check
8081
self.rerun = rerun # Assigning the rerun parameter
82+
self.retain_full_context = retain_full_context
8183

8284
# Set logger level based on config verbose level
8385
verbose = self.config.get("verbose", 0)
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Test script to verify context management functionality
4+
"""
5+
6+
import sys
7+
import os
8+
sys.path.append(os.path.join(os.path.dirname(__file__), 'src', 'praisonai-agents'))
9+
10+
from praisonaiagents.task.task import Task
11+
from praisonaiagents.agent.agent import Agent
12+
from praisonaiagents.process.process import Process
13+
from praisonaiagents.main import TaskOutput
14+
15+
def test_context_management():
16+
"""Test context management with and without retain_full_context"""
17+
18+
print("Starting context management test...")
19+
20+
try:
21+
# Create a mock agent
22+
agent = Agent(name="Test Agent", role="Tester", goal="Test context management")
23+
print("✓ Created test agent")
24+
25+
# Create tasks with results
26+
task1 = Task(
27+
name="task1",
28+
description="First task",
29+
agent=agent,
30+
status="completed"
31+
)
32+
task1.result = TaskOutput(
33+
description="First task",
34+
raw="Result from task 1",
35+
agent="Test Agent"
36+
)
37+
print("✓ Created task1")
38+
39+
task2 = Task(
40+
name="task2",
41+
description="Second task",
42+
agent=agent,
43+
status="completed"
44+
)
45+
task2.result = TaskOutput(
46+
description="Second task",
47+
raw="Result from task 2",
48+
agent="Test Agent"
49+
)
50+
# Set up the previous_tasks manually since it's not a constructor parameter
51+
task2.previous_tasks = ["task1"]
52+
print("✓ Created task2")
53+
54+
# Test case 1: Default behavior (retain_full_context=False)
55+
task3_limited = Task(
56+
name="task3_limited",
57+
description="Third task with limited context",
58+
agent=agent,
59+
retain_full_context=False # Default behavior
60+
)
61+
# Set up the previous_tasks manually
62+
task3_limited.previous_tasks = ["task1", "task2"]
63+
print("✓ Created task3_limited")
64+
65+
# Test case 2: Full context retention (retain_full_context=True)
66+
task3_full = Task(
67+
name="task3_full",
68+
description="Third task with full context",
69+
agent=agent,
70+
retain_full_context=True # Original behavior
71+
)
72+
# Set up the previous_tasks manually
73+
task3_full.previous_tasks = ["task1", "task2"]
74+
print("✓ Created task3_full")
75+
76+
# Create process and test context building
77+
tasks_dict = {
78+
task1.id: task1,
79+
task2.id: task2,
80+
task3_limited.id: task3_limited,
81+
task3_full.id: task3_full
82+
}
83+
84+
process = Process(tasks=tasks_dict, agents=[agent])
85+
print("✓ Created process")
86+
87+
# Test limited context
88+
limited_context = process._build_task_context(task3_limited)
89+
print("Limited context (retain_full_context=False):")
90+
print(f"'{limited_context}'")
91+
print()
92+
93+
# Test full context
94+
full_context = process._build_task_context(task3_full)
95+
print("Full context (retain_full_context=True):")
96+
print(f"'{full_context}'")
97+
print()
98+
99+
# Verify results
100+
assert "task2" in limited_context, "Limited context should include most recent task"
101+
assert "task1" not in limited_context, "Limited context should NOT include earlier tasks"
102+
103+
assert "task1" in full_context, "Full context should include all previous tasks"
104+
assert "task2" in full_context, "Full context should include all previous tasks"
105+
106+
print("✅ All tests passed!")
107+
print("- Limited context only includes most recent previous task")
108+
print("- Full context includes all previous tasks")
109+
print("- Backwards compatibility maintained")
110+
111+
except Exception as e:
112+
print(f"❌ Test failed with error: {e}")
113+
import traceback
114+
traceback.print_exc()
115+
raise
116+
117+
if __name__ == "__main__":
118+
test_context_management()

0 commit comments

Comments
 (0)