-
Notifications
You must be signed in to change notification settings - Fork 57
Expand file tree
/
Copy pathoutput_parser.py
More file actions
145 lines (127 loc) · 5.87 KB
/
Copy pathoutput_parser.py
File metadata and controls
145 lines (127 loc) · 5.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import json
import re
import string
from json import JSONDecodeError
from typing import Dict, List, Optional
from steamship import Block, MimeTypes, Steamship, Tag
from steamship.agents.schema import Action, AgentContext, FinishAction, OutputParser, Tool
from steamship.data.tags.tag_constants import RoleTag, TagKind
from steamship.utils.utils import is_valid_uuid4
def is_punctuation(text: str):
    """Return True when every character of ``text`` is ASCII punctuation.

    Membership is tested against ``string.punctuation``. An empty string has
    no offending character and therefore also yields True.
    """
    return all(char in string.punctuation for char in text)
class FunctionsBasedOutputParser(OutputParser):
    """Turn raw LLM output into an agent ``Action``.

    Text that is a JSON ``{"function_call": {...}}`` payload becomes an
    ``Action`` targeting the named tool. Any other text becomes a
    ``FinishAction`` whose output is the text, with referenced block ids
    replaced by the blocks themselves.
    """

    # Maps tool name -> Tool; filled in by ``__init__`` from the ``tools`` keyword.
    tools_lookup_dict: Optional[Dict[str, Tool]] = None

    def __init__(self, **kwargs):
        """Build the name -> tool lookup from the optional ``tools`` keyword.

        All remaining keyword arguments are forwarded to the parent class.
        """
        # ``or []`` also tolerates an explicit ``tools=None``.
        tools_lookup_dict = {tool.name: tool for tool in kwargs.pop("tools", None) or []}
        super().__init__(tools_lookup_dict=tools_lookup_dict, **kwargs)

    @staticmethod
    def _text_argument_block(value: str) -> Block:
        """Wrap a literal text argument in a new Block tagged as a ``text`` function arg."""
        return Block(
            text=value,
            tags=[Tag(kind=TagKind.FUNCTION_ARG, name="text")],
            mime_type=MimeTypes.TXT,
        )

    @staticmethod
    def _uuid_argument_block(client: Steamship, block_id: str) -> Block:
        """Fetch the block with ``block_id`` and tag it as a ``uuid`` function arg."""
        existing_block = Block.get(client, _id=block_id)
        tag = Tag.create(
            existing_block.client,
            file_id=existing_block.file_id,
            block_id=existing_block.id,
            kind=TagKind.FUNCTION_ARG,
            name="uuid",
        )
        existing_block.tags.append(tag)
        return existing_block

    def _extract_action_from_function_call(self, text: str, context: AgentContext) -> Action:
        """Build the tool ``Action`` described by a JSON function-call payload.

        Args:
            text: JSON text expected to look like
                ``{"function_call": {"name": ..., "arguments": ...}}``.
            context: Agent context; supplies the client used to fetch blocks.

        Returns:
            An ``Action`` for the named tool, with zero or one input block.

        Raises:
            JSONDecodeError: If ``text`` is not valid JSON, or is valid JSON
                but not an object holding a ``function_call`` object.
            RuntimeError: If the named tool is not in ``tools_lookup_dict``.
        """
        wrapper = json.loads(text)
        fc = wrapper.get("function_call") if isinstance(wrapper, dict) else None
        if not isinstance(fc, dict):
            # Valid JSON that merely mentions "function_call" is reported the
            # same way as unparseable JSON, so `parse` treats it as a regular
            # message instead of crashing with AttributeError.
            raise JSONDecodeError("Text is not a function_call payload", text, 0)

        # `name` may be missing or null; normalise it to a string before use.
        name = str(fc.get("name") or "")
        if name.startswith("functions."):
            name = name[len("functions.") :]  # occasionally, OpenAI prepends "functions."
        tool = self.tools_lookup_dict.get(name, None)
        if tool is None:
            raise RuntimeError(
                f"Could not find tool from function call: `{name}`. Known tools: {self.tools_lookup_dict.keys()}"
            )

        input_blocks = []
        arguments = fc.get("arguments")
        if arguments:
            if isinstance(arguments, dict):
                # Already-decoded arguments are used as they are.
                args = arguments
            else:
                try:
                    args = json.loads(arguments)
                except (JSONDecodeError, TypeError):
                    args = None

            if isinstance(args, dict):
                if text_arg := args.get("text"):
                    input_blocks.append(self._text_argument_block(text_arg))
                elif uuid_arg := args.get("uuid"):
                    input_blocks.append(self._uuid_argument_block(context.client, uuid_arg))
            else:
                # The arguments were not a JSON object: treat them as one bare
                # value. A JSON-encoded string contributes its decoded value;
                # anything else falls back to the raw argument text.
                raw = args if isinstance(args, str) else arguments
                if isinstance(raw, str):
                    if is_valid_uuid4(raw):
                        input_blocks.append(self._uuid_argument_block(context.client, raw))
                    else:
                        input_blocks.append(self._text_argument_block(raw))

        return Action(tool=tool.name, input=input_blocks, context=context)

    @staticmethod
    def _blocks_from_text(client: Steamship, text: str) -> List[Block]:
        """Split a response into text blocks and fetched blocks.

        Only the part of ``text`` after the last ``"AI:"`` marker is used.
        Each uppercase UUID found in it (optionally wrapped as ``Block(...)``,
        ``[Block(...)]`` and similar) is replaced by the block fetched with
        that id. The text between ids becomes plain text blocks. Leftover text
        that is only punctuation is dropped.
        """
        last_response = text.split("AI:")[-1].strip()
        block_id_regex = r"(?:(?:\[|\()?Block)?\(?([A-F0-9]{8}\-[A-F0-9]{4}\-[A-F0-9]{4}\-[A-F0-9]{4}\-[A-F0-9]{12})\)?(?:(\]|\)))?"
        remaining_text = last_response
        result_blocks: List[Block] = []
        while remaining_text.strip():
            if is_punctuation(remaining_text.strip()):
                # Only stray punctuation is left; nothing worth emitting.
                break
            match = re.search(block_id_regex, remaining_text)
            if match is None:
                result_blocks.append(Block(text=remaining_text))
                break
            pre_block_text = FunctionsBasedOutputParser._remove_block_prefix(
                candidate=remaining_text[: match.start()]
            )
            if pre_block_text:
                result_blocks.append(Block(text=pre_block_text))
            result_blocks.append(Block.get(client, _id=match.group(1)))
            remaining_text = FunctionsBasedOutputParser._remove_block_suffix(
                remaining_text[match.end() :]
            )
        return result_blocks

    @staticmethod
    def _remove_block_prefix(candidate: str) -> str:
        """Strip a trailing ``(Block``, ``[Block`` or ``Block`` marker from ``candidate``.

        ``candidate`` is the text that precedes a block id, so the marker sits
        at its end and must be cut from the end.
        """
        if candidate.endswith(("(Block", "[Block")):
            return candidate[: -len("(Block")]
        if candidate.endswith("Block"):
            return candidate[: -len("Block")]
        return candidate

    @staticmethod
    def _remove_block_suffix(candidate: str) -> str:
        """Strip one leading ``)`` or ``]`` from ``candidate``.

        ``candidate`` is the text that follows a block id, so a leftover
        closing bracket sits at its start.
        """
        if candidate.startswith((")", "]")):
            return candidate[1:]
        return candidate

    def parse(self, text: str, context: AgentContext) -> Action:
        """Parse LLM output into a tool ``Action`` or a ``FinishAction``.

        Args:
            text: Raw text produced by the LLM.
            context: Agent context; supplies the client and the request id.

        Returns:
            The tool ``Action`` when ``text`` is a function-call payload,
            otherwise a ``FinishAction`` whose blocks carry the assistant
            chat role and the context's request id.
        """
        if "function_call" in text:
            try:
                # catch invalid JSON. If it is not valid JSON, just treat as "regular" message
                return self._extract_action_from_function_call(text, context)
            except JSONDecodeError:
                pass

        finish_blocks = FunctionsBasedOutputParser._blocks_from_text(context.client, text)
        for finish_block in finish_blocks:
            finish_block.set_chat_role(RoleTag.ASSISTANT)
            finish_block.set_request_id(context.request_id)
        return FinishAction(output=finish_blocks, context=context)