Skip to content

Commit b9e32cc

Browse files
douglas-reidDouglas Reid
andauthored
fix(agents): generate proper history blocks in agent (#570)
With the inflight changes for streaming, we unfortunately merged a commit that left message selection in a broken state. The existing testing was not enough to capture the issue. This PR attempts to restore proper message selection functionality. Follow on PRs will be necessary to clean up and streamline the selection bits added in this PR. Co-authored-by: Douglas Reid <doug@steamship.com>
1 parent f659229 commit b9e32cc

13 files changed

Lines changed: 384 additions & 61 deletions

File tree

src/steamship/agents/examples/example_assistant.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
1+
from typing import Type
2+
3+
from pydantic.fields import Field
4+
15
from steamship.agents.functional import FunctionsBasedAgent
26
from steamship.agents.llms.openai import ChatOpenAI
37
from steamship.agents.schema.message_selectors import MessageWindowMessageSelector
48
from steamship.agents.service.agent_service import AgentService
59
from steamship.agents.tools.image_generation import DalleTool
610
from steamship.agents.tools.search import SearchTool
11+
from steamship.invocable import Config
712
from steamship.utils.repl import AgentREPL
813

914

@@ -13,6 +18,13 @@ class MyFunctionsBasedAssistant(AgentService):
1318
to provide an overview of the types of tasks it can accomplish (here, search
1419
and image generation)."""
1520

21+
class AgentConfig(Config):
22+
model_name: str = Field(default="gpt-4")
23+
24+
@classmethod
25+
def config_cls(cls) -> Type[Config]:
26+
return MyFunctionsBasedAssistant.AgentConfig
27+
1628
def __init__(self, **kwargs):
1729
super().__init__(**kwargs)
1830
self.set_default_agent(
@@ -21,7 +33,7 @@ def __init__(self, **kwargs):
2133
SearchTool(),
2234
DalleTool(),
2335
],
24-
llm=ChatOpenAI(self.client, temperature=0),
36+
llm=ChatOpenAI(self.client, temperature=0, model_name=self.config.model_name),
2537
message_selector=MessageWindowMessageSelector(k=2),
2638
)
2739
)
@@ -31,4 +43,6 @@ def __init__(self, **kwargs):
3143
# AgentREPL provides a mechanism for local execution of an AgentService method.
3244
# This is used for simplified debugging as agents and tools are developed and
3345
# added.
34-
AgentREPL(MyFunctionsBasedAssistant, agent_package_config={}).run(dump_history_on_exit=True)
46+
AgentREPL(MyFunctionsBasedAssistant, agent_package_config={"model_name": "gpt-3.5-turbo"}).run(
47+
dump_history_on_exit=True
48+
)

src/steamship/agents/functional/functions_based.py

Lines changed: 78 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
import json
2+
from operator import attrgetter
13
from typing import List
24

3-
from steamship import Block
5+
from steamship import Block, MimeTypes, Tag
46
from steamship.agents.functional.output_parser import FunctionsBasedOutputParser
5-
from steamship.agents.schema import Action, AgentContext, ChatAgent, ChatLLM, Tool
6-
from steamship.data.tags.tag_constants import RoleTag
7+
from steamship.agents.schema import Action, AgentContext, ChatAgent, ChatLLM, FinishAction, Tool
8+
from steamship.data.tags.tag_constants import ChatTag, RoleTag, TagKind, TagValueKey
9+
from steamship.data.tags.tag_utils import get_tag
710

811

912
class FunctionsBasedAgent(ChatAgent):
@@ -54,6 +57,8 @@ def build_chat_history_for_tool(self, context: AgentContext) -> List[Block]:
5457
# get most recent context
5558
messages_from_memory.extend(context.chat_history.select_messages(self.message_selector))
5659

60+
messages_from_memory.sort(key=attrgetter("index_in_file"))
61+
5762
# de-dupe the messages from memory
5863
ids = [context.chat_history.last_user_message.id]
5964
for msg in messages_from_memory:
@@ -67,10 +72,8 @@ def build_chat_history_for_tool(self, context: AgentContext) -> List[Block]:
6772
# this should happen BEFORE any agent/assistant messages related to tool selection
6873
messages.append(context.chat_history.last_user_message)
6974

70-
# get completed steps
71-
actions = context.completed_steps
72-
for action in actions:
73-
messages.extend(action.to_chat_messages())
75+
# get working history (completed actions)
76+
messages.extend(self._function_calls_since_last_user_message(context))
7477

7578
return messages
7679

@@ -81,4 +84,71 @@ def next_action(self, context: AgentContext) -> Action:
8184
# Run the default LLM on those messages
8285
output_blocks = self.llm.chat(messages=messages, tools=self.tools)
8386

84-
return self.output_parser.parse(output_blocks[0].text, context)
87+
future_action = self.output_parser.parse(output_blocks[0].text, context)
88+
if not isinstance(future_action, FinishAction):
89+
# record the LLM's function response in history
90+
self._record_action_selection(future_action, context)
91+
return future_action
92+
93+
def _function_calls_since_last_user_message(self, context: AgentContext) -> List[Block]:
94+
function_calls = []
95+
for block in context.chat_history.messages[::-1]: # is this too inefficient at scale?
96+
if block.chat_role == RoleTag.USER:
97+
return reversed(function_calls)
98+
if get_tag(block.tags, kind=TagKind.ROLE, name=RoleTag.FUNCTION):
99+
function_calls.append(block)
100+
elif get_tag(block.tags, kind=TagKind.FUNCTION_SELECTION):
101+
function_calls.append(block)
102+
return reversed(function_calls)
103+
104+
def _to_openai_function_selection(self, action: Action) -> str:
105+
"""NOTE: Temporary placeholder. Should be refactored"""
106+
fc = {"name": action.tool}
107+
args = {}
108+
for block in action.input:
109+
for t in block.tags:
110+
if t.kind == TagKind.FUNCTION_ARG:
111+
args[t.name] = block.as_llm_input(exclude_block_wrapper=True)
112+
113+
fc["arguments"] = json.dumps(args) # the arguments must be a string value NOT a dict
114+
return json.dumps(fc)
115+
116+
def _record_action_selection(self, action: Action, context: AgentContext):
117+
tags = [
118+
Tag(
119+
kind=TagKind.CHAT,
120+
name=ChatTag.ROLE,
121+
value={TagValueKey.STRING_VALUE: RoleTag.ASSISTANT},
122+
),
123+
Tag(kind=TagKind.FUNCTION_SELECTION, name=action.tool),
124+
]
125+
context.chat_history.file.append_block(
126+
text=self._to_openai_function_selection(action), tags=tags, mime_type=MimeTypes.TXT
127+
)
128+
129+
def record_action_run(self, action: Action, context: AgentContext):
130+
super().record_action_run(action, context)
131+
132+
if isinstance(action, FinishAction):
133+
return
134+
135+
tags = [
136+
Tag(
137+
kind=TagKind.ROLE,
138+
name=RoleTag.FUNCTION,
139+
value={TagValueKey.STRING_VALUE: action.tool},
140+
),
141+
# need the following tag for backwards compatibility with older gpt-4 plugin
142+
Tag(
143+
kind="name",
144+
name=action.tool,
145+
),
146+
]
147+
# TODO(dougreid): I'm not convinced this is correct for tools that return multiple values.
148+
# It _feels_ like these should be named and inlined as a single message in history, etc.
149+
for block in action.output:
150+
context.chat_history.file.append_block(
151+
text=block.as_llm_input(exclude_block_wrapper=True),
152+
tags=tags,
153+
mime_type=block.mime_type,
154+
)

src/steamship/agents/functional/output_parser.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
from json import JSONDecodeError
55
from typing import Dict, List, Optional
66

7-
from steamship import Block, MimeTypes, Steamship
7+
from steamship import Block, MimeTypes, Steamship, Tag
88
from steamship.agents.schema import Action, AgentContext, FinishAction, OutputParser, Tool
9-
from steamship.data.tags.tag_constants import RoleTag
9+
from steamship.data.tags.tag_constants import RoleTag, TagKind
1010
from steamship.utils.utils import is_valid_uuid4
1111

1212

@@ -43,16 +43,45 @@ def _extract_action_from_function_call(self, text: str, context: AgentContext) -
4343
try:
4444
args = json.loads(arguments)
4545
if text := args.get("text"):
46-
input_blocks.append(Block(text=text, mime_type=MimeTypes.TXT))
46+
input_blocks.append(
47+
Block(
48+
text=text,
49+
tags=[Tag(kind=TagKind.FUNCTION_ARG, name="text")],
50+
mime_type=MimeTypes.TXT,
51+
)
52+
)
4753
elif uuid_arg := args.get("uuid"):
48-
input_blocks.append(Block.get(context.client, _id=uuid_arg))
54+
existing_block = Block.get(context.client, _id=uuid_arg)
55+
tag = Tag.create(
56+
existing_block.client,
57+
file_id=existing_block.file_id,
58+
block_id=existing_block.id,
59+
kind=TagKind.FUNCTION_ARG,
60+
name="uuid",
61+
)
62+
existing_block.tags.append(tag)
63+
input_blocks.append(existing_block)
4964
except json.decoder.JSONDecodeError:
5065
if isinstance(arguments, str):
5166
if is_valid_uuid4(arguments):
52-
input_blocks.append(Block.get(context.client, _id=uuid_arg))
67+
existing_block = Block.get(context.client, _id=arguments)
68+
tag = Tag.create(
69+
existing_block.client,
70+
file_id=existing_block.file_id,
71+
block_id=existing_block.id,
72+
kind=TagKind.FUNCTION_ARG,
73+
name="uuid",
74+
)
75+
existing_block.tags.append(tag)
76+
input_blocks.append(existing_block)
5377
else:
54-
input_blocks.append(Block(text=arguments, mime_type=MimeTypes.TXT))
55-
78+
input_blocks.append(
79+
Block(
80+
text=arguments,
81+
tags=[Tag(kind=TagKind.FUNCTION_ARG, name="text")],
82+
mime_type=MimeTypes.TXT,
83+
)
84+
)
5685
return Action(tool=tool.name, input=input_blocks, context=context)
5786

5887
@staticmethod

src/steamship/agents/schema/action.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@
22

33
from pydantic import BaseModel
44

5-
from steamship import Block, Tag
6-
from steamship.data import TagKind
7-
from steamship.data.tags.tag_constants import RoleTag
5+
from steamship import Block
86

97

108
class Action(BaseModel):
@@ -28,25 +26,35 @@ class Action(BaseModel):
2826
Setting this to True means that the executing Agent should halt any reasoning.
2927
"""
3028

31-
def to_chat_messages(self) -> List[Block]:
32-
tags = [
33-
Tag(kind=TagKind.ROLE, name=RoleTag.FUNCTION),
34-
Tag(kind="name", name=self.tool),
35-
]
36-
blocks = []
37-
for block in self.output:
38-
# TODO(dougreid): should we revisit as_llm_input? we might need only the UUID...
39-
blocks.append(
40-
Block(
41-
text=block.as_llm_input(exclude_block_wrapper=True),
42-
tags=tags,
43-
mime_type=block.mime_type,
44-
)
45-
)
46-
47-
# TODO(dougreid): revisit when have multiple output functions.
48-
# Current thinking: LLM will be OK with multiple function blocks in a row. NEEDS validation.
49-
return blocks
29+
# def to_chat_messages(self) -> List[Block]:
30+
# blocks = []
31+
# for arg in self.input:
32+
#
33+
#
34+
# blocks.append(
35+
# Block(
36+
# text=json.dumps({"name": f"{self.tool}", "arguments": "{ \"text\": \"who is the current president of Taiwan?\" }"}),
37+
# )
38+
# )
39+
#
40+
# tags = [
41+
# Tag(kind=TagKind.ROLE, name=RoleTag.FUNCTION),
42+
# Tag(kind="name", name=self.tool),
43+
# ]
44+
#
45+
# for block in self.output:
46+
# # TODO(dougreid): should we revisit as_llm_input? we might need only the UUID...
47+
# blocks.append(
48+
# Block(
49+
# text=block.as_llm_input(exclude_block_wrapper=True),
50+
# tags=tags,
51+
# mime_type=block.mime_type,
52+
# )
53+
# )
54+
#
55+
# # TODO(dougreid): revisit when have multiple output functions.
56+
# # Current thinking: LLM will be OK with multiple function blocks in a row. NEEDS validation.
57+
# return blocks
5058

5159

5260
class FinishAction(Action):

src/steamship/agents/schema/agent.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ class Agent(BaseModel, ABC):
3131
def next_action(self, context: AgentContext) -> Action:
3232
pass
3333

34+
def record_action_run(self, action: Action, context: AgentContext):
35+
# TODO(dougreid): should this method (or just bit) actually be on AgentContext?
36+
context.completed_steps.append(action)
37+
3438

3539
class LLMAgent(Agent):
3640
"""LLMAgents choose next actions for an AgentService based on interactions with an LLM."""

src/steamship/agents/schema/chathistory.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,11 @@ def append_message_with_role(
144144
text=text, tags=tags, content=content, url=url, mime_type=mime_type
145145
)
146146
# don't index status messages
147-
if self.embedding_index is not None and role is not RoleTag.AGENT:
147+
if self.embedding_index is not None and role not in [
148+
RoleTag.AGENT,
149+
RoleTag.TOOL,
150+
RoleTag.LLM,
151+
]:
148152
chunk_tags = self.text_splitter.chunk_text_to_tags(
149153
block, kind=TagKind.CHAT, name=ChatTag.CHUNK
150154
)

src/steamship/agents/schema/message_selectors.py

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from pydantic.main import BaseModel
66

77
from steamship import Block
8-
from steamship.data.tags.tag_constants import RoleTag
8+
from steamship.data.tags.tag_constants import RoleTag, TagKind
9+
from steamship.data.tags.tag_utils import get_tag
910

1011

1112
class MessageSelector(BaseModel, ABC):
@@ -29,23 +30,53 @@ def is_assistant_message(block: Block) -> bool:
2930
return role == RoleTag.ASSISTANT
3031

3132

33+
def is_function_message(block: Block) -> bool:
34+
is_function_selection = get_tag(block.tags, kind=TagKind.FUNCTION_SELECTION)
35+
return is_function_selection
36+
37+
38+
def is_tool_function_message(block: Block) -> bool:
39+
is_function_call = get_tag(block.tags, kind=TagKind.ROLE, name=RoleTag.FUNCTION)
40+
return is_function_call
41+
42+
43+
def is_user_history_message(block: Block) -> bool:
44+
return is_user_message(block) or (
45+
is_assistant_message(block) and not is_function_message(block)
46+
)
47+
48+
3249
class MessageWindowMessageSelector(MessageSelector):
3350
k: int
3451

3552
def get_messages(self, messages: List[Block]) -> List[Block]:
3653
msgs = messages[:]
37-
msgs.pop() # don't add the current prompt to the memory
38-
if len(msgs) <= (self.k * 2):
39-
return msgs
40-
54+
# msgs.pop()
55+
have_seen_user_message = False
56+
if is_user_message(msgs[-1]):
57+
have_seen_user_message = True
58+
msgs.pop() # don't add the current prompt to the memory
4159
selected_msgs = []
60+
conversation_messages = 0
4261
limit = self.k * 2
43-
scope = msgs[len(messages) - limit :]
44-
for block in scope:
45-
if is_user_message(block) or is_assistant_message(block):
62+
message_index = len(msgs) - 1
63+
while (conversation_messages < limit) and (message_index > 0):
64+
# TODO(dougreid): i _think_ we don't need the function return if we have a user-assistant pair
65+
# but, for safety here, we try to add non-current function blocks from past iterations.
66+
block = msgs[message_index]
67+
if is_user_message(block):
68+
have_seen_user_message = True
69+
if is_user_history_message(block):
4670
selected_msgs.append(block)
71+
conversation_messages += 1
72+
elif have_seen_user_message and (
73+
is_function_message(block) or is_tool_function_message(block)
74+
):
75+
# conditionally append working function call messages
76+
selected_msgs.append(block)
77+
message_index -= 1
4778

48-
return selected_msgs
79+
return reversed(selected_msgs)
4980

5081

5182
def tokens(block: Block) -> int:
@@ -62,9 +93,11 @@ def get_messages(self, messages: List[Block]) -> List[Block]:
6293
current_tokens = 0
6394

6495
msgs = messages[:]
65-
msgs.pop() # don't add the current prompt to the memory
96+
if is_user_message(msgs[-1]):
97+
msgs.pop() # don't add the current prompt to the memory
98+
6699
for block in reversed(msgs):
67-
if block.chat_role != RoleTag.SYSTEM and current_tokens < self.max_tokens:
100+
if is_user_history_message(block) and current_tokens < self.max_tokens:
68101
block_tokens = tokens(block)
69102
if block_tokens + current_tokens < self.max_tokens:
70103
selected_messages.append(block)

0 commit comments

Comments
 (0)