Skip to content

Commit 372dca4

Browse files
eob and douglas-reid authored
feat: Add source attribution reporting on DocumentIndexingPipeline + Tests (#490)
Right now, the document indexing pipeline does not properly include source material attribution. This PR adds that attribution and tests it. The source attribution is both automatic (fileId, blockId, page) and user-controlled (via the metadata argument passed in at index request time). Co-authored-by: Douglas Reid <douglas-reid@users.noreply.github.com>
1 parent 603d88a commit 372dca4

14 files changed

Lines changed: 208 additions & 35 deletions

File tree

src/steamship/agents/functional/output_parser.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ def _extract_action_from_function_call(self, text: str, context: AgentContext) -
2828
wrapper = json.loads(text)
2929
fc = wrapper.get("function_call")
3030
name = fc.get("name", "")
31+
if name.startswith("functions."):
32+
name = name[len("functions.") :] # occasionally, OpenAI prepends "functions."
3133
tool = self.tools_lookup_dict.get(name, None)
3234
if tool is None:
3335
raise RuntimeError(

src/steamship/agents/tools/question_answering/vector_search_qa_tool.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
import logging
33
from typing import Any, List, Optional, Union
44

5-
from steamship import Block, Tag, Task
5+
from steamship import Block, DocTag, Tag, Task
66
from steamship.agents.llms import OpenAI
77
from steamship.agents.logging import AgentLogging
88
from steamship.agents.schema import AgentContext
99
from steamship.agents.tools.question_answering.vector_search_tool import VectorSearchTool
1010
from steamship.agents.utils import get_llm, with_llm
11+
from steamship.data import TagKind
1112
from steamship.utils.repl import ToolREPL
1213

1314
DEFAULT_QUESTION_ANSWERING_PROMPT = (
@@ -45,11 +46,16 @@ def answer_question(self, question: str, context: AgentContext) -> List[Block]:
4546
task.wait()
4647

4748
source_texts = []
49+
source_metadata = []
4850

4951
for item in task.output.items:
5052
if item.tag and item.tag.text:
5153
item_data = {"text": item.tag.text}
5254
source_texts.append(self.source_document_prompt.format(**item_data))
55+
_metadata = {}
56+
if item.tag.value:
57+
_metadata.update(item.tag.value)
58+
source_metadata.append(_metadata)
5359

5460
final_prompt = self.question_answering_prompt.format(
5561
**{"source_text": "\n".join(source_texts), "question": question}
@@ -65,8 +71,16 @@ def answer_question(self, question: str, context: AgentContext) -> List[Block]:
6571
"prompt": final_prompt,
6672
},
6773
)
68-
69-
return get_llm(context, default=OpenAI(client=context.client)).complete(prompt=final_prompt)
74+
output_blocks = get_llm(context, default=OpenAI(client=context.client)).complete(
75+
prompt=final_prompt
76+
)
77+
for output_block in output_blocks:
78+
if output_block.tags is None:
79+
output_block.tags = []
80+
output_block.tags.append(
81+
Tag(kind=TagKind.DOCUMENT, name=DocTag.SOURCE, value={"sources": source_metadata})
82+
)
83+
return output_blocks
7084

7185
def run(self, tool_input: List[Block], context: AgentContext) -> Union[List[Block], Task[Any]]:
7286
"""Answers questions with the assistance of an Embedding Index plugin.

src/steamship/data/tags/tag_constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ class DocTag(str, Enum):
7373
CHAPTER = "chapter"
7474
TEXT = "text"
7575
CHAT = "chat"
76+
METADATA = "metadata"
7677

7778
@staticmethod
7879
def from_html_tag(tagname: Optional[str]) -> Optional["DocTag"]: # noqa: C901

src/steamship/invocable/mixins/blockifier_mixin.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def blockify(
3838

3939
_mime_type = mime_type or file.mime_type
4040
if not _mime_type:
41+
update_file_status(self.client, file, "Failed Blockifying")
4142
raise SteamshipError(
4243
message=f"No MIME Type found for file {file.id}. Unable to blockify."
4344
)
@@ -54,6 +55,7 @@ def blockify(
5455
plugin_instance = self.client.use_plugin("markdown-blockifier-default")
5556

5657
if not plugin_instance:
58+
update_file_status(self.client, file, "Failed Blockifying")
5759
raise SteamshipError(
5860
message=f"Unable to blockify file {file.id}. MIME Type {_mime_type} unsupported"
5961
)

src/steamship/invocable/mixins/file_importer_mixin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def import_url_to_file_and_task(self, url: str) -> Tuple[File, Optional[Task]]:
9595

9696
@post("/import_url")
9797
def import_url(self, url: str) -> File:
98-
"""Import the URL to a Steamship File. Actual import will be scheduled Async."""
98+
"""Import the URL to a Steamship File. Actual import will be scheduled async."""
9999
file, task = self.import_url_to_file_and_task(url)
100100
return file
101101

src/steamship/invocable/mixins/indexer_mixin.py

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from typing import Optional, cast
22

33
from steamship import Block, DocTag, File, Steamship, Tag
4-
from steamship.data import TagValueKey
4+
from steamship.data import TagKind, TagValueKey
55
from steamship.data.plugin.index_plugin_instance import EmbeddingIndexPluginInstance, SearchResults
66
from steamship.invocable import post
77
from steamship.invocable.package_mixin import PackageMixin
88
from steamship.utils.file_tags import update_file_status
9+
from steamship.utils.text_chunker import chunk_text
910

1011
DEFAULT_EMBEDDING_INDEX_CONFIG = {
1112
"embedder": {
@@ -71,13 +72,17 @@ def _get_index(self, index_handle: Optional[str] = None) -> EmbeddingIndexPlugin
7172
def index_text(
7273
self, text: str, metadata: Optional[dict] = None, index_handle: Optional[str] = None
7374
) -> bool:
75+
"""Load text into an embedding index.
76+
77+
Optional arguments:
78+
- index_handle (uses your default index if blank)
79+
- metadata (returned on embedding results for source attribution)
80+
"""
7481
tags = []
75-
for i in range(0, len(text), self.context_window_size):
76-
# Calculate the extent of the window plus the overlap at the edges
77-
min_range = max(0, i - self.context_window_overlap)
78-
max_range = i + self.context_window_size + self.context_window_overlap
79-
chunk = text[min_range:max_range]
80-
tags.append(Tag(text=chunk, metadata=metadata))
82+
for chunk in chunk_text(
83+
text, chunk_size=self.context_window_size, chunk_overlap=self.context_window_overlap
84+
):
85+
tags.append(Tag(text=chunk, value=metadata))
8186
self._get_index(index_handle).insert(tags)
8287
return True
8388

@@ -88,9 +93,9 @@ def _index_block(
8893
_metadata = {}
8994
if metadata:
9095
_metadata.update(metadata)
96+
9197
_metadata.update(
9298
{
93-
"source": "",
9499
"file_id": block.file_id,
95100
"block_id": block.id,
96101
"page": page_id,
@@ -103,13 +108,18 @@ def _index_block(
103108
def index_block(
104109
self, block_id: str, metadata: Optional[dict] = None, index_handle: Optional[str] = None
105110
):
111+
"""Load a Steamship Block into an embedding index.
112+
113+
Optional arguments:
114+
- index_handle (uses your default index if blank)
115+
- metadata (returned on embedding results for source attribution)
116+
"""
106117
block = Block.get(self.client, _id=block_id)
107118
page_id = self._get_page(block)
108119
_metadata = {}
109120
_metadata.update(metadata)
110121
_metadata.update(
111122
{
112-
"source": "",
113123
"file_id": block.file_id,
114124
"block_id": block.id,
115125
"page": page_id,
@@ -122,11 +132,29 @@ def index_block(
122132
def index_file(
123133
self, file_id: str, metadata: Optional[dict] = None, index_handle: Optional[str] = None
124134
) -> bool:
135+
"""Load a Steamship File into an embedding index.
136+
137+
Optional arguments:
138+
- index_handle (uses your default index if blank)
139+
- metadata (returned on embedding results for source attribution)
140+
"""
125141
file = File.get(self.client, _id=file_id)
126142
update_file_status(self.client, file, "Indexing")
127143

144+
_metadata = {}
145+
if file.mime_type:
146+
_metadata["mime_type"] = file.mime_type
147+
148+
for tag in file.tags or []:
149+
if tag.kind == TagKind.DOCUMENT and tag.name == DocTag.TITLE:
150+
if title := tag.value.get(TagValueKey.STRING_VALUE):
151+
_metadata["title"] = title
152+
153+
if metadata:
154+
_metadata.update(metadata)
155+
128156
for block in file.blocks or []:
129-
self._index_block(block, metadata=metadata, index_handle=index_handle)
157+
self._index_block(block, metadata=_metadata, index_handle=index_handle)
130158

131159
update_file_status(self.client, file, "Indexed")
132160
return True
@@ -135,6 +163,11 @@ def index_file(
135163
def search_index(
136164
self, query: str, index_handle: Optional[str] = None, k: int = 5
137165
) -> SearchResults:
166+
"""Search an embedding index.
167+
168+
Optional arguments:
169+
- index_handle (uses your default index if blank)
170+
"""
138171
index = self._get_index(index_handle)
139172
task = index.search(query, k)
140173
return task.wait()

src/steamship/invocable/mixins/indexer_pipeline_mixin.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,18 @@ def index_url(
5353
index_handle: Optional[str] = None,
5454
mime_type: Optional[str] = None,
5555
) -> Task:
56+
"""Load a URL into an embedding index.
57+
58+
URL Types supported:
59+
- PDF (Text)
60+
- TXT and Markdown
61+
- YouTube (Though failure rate is high)
62+
63+
Optional arguments:
64+
- mime_type (if it can be guessed by the Content-Type header or the URL schema)
65+
- index_handle (uses your default index if blank)
66+
- metadata (returned on embedding results for source attribution)
67+
"""
5668
# Step 1: Import the URL
5769
file, task = self.importer_mixin.import_url_to_file_and_task(url)
5870

@@ -66,13 +78,14 @@ def index_url(
6678
)
6779

6880
# Step 3: Index the File
81+
_metadata = {"url": url}
82+
if metadata is not None:
83+
_metadata.update(metadata)
84+
6985
index_task = self.invocable.invoke_later(
7086
method="index_file",
7187
wait_on_tasks=[blockify_task],
72-
arguments={
73-
"file_id": file.id,
74-
"index_handle": index_handle,
75-
},
88+
arguments={"file_id": file.id, "index_handle": index_handle, "metadata": _metadata},
7689
)
7790

7891
# Step 4: Set the File Status to 'indexed'

src/steamship/utils/repl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from termcolor import colored # noqa: F401
2020
except ImportError:
2121

22-
def colored(text: str, **kwargs):
22+
def colored(text: str, color: str, **kwargs):
2323
print(text)
2424

2525

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import logging
2+
3+
4+
def chunk_text(text: str, chunk_size: int = 200, chunk_overlap: int = 50):
    """Yield successive, overlapping chunks of ``text`` for an embedding index.

    Each chunk is at most ``chunk_size`` characters long, and consecutive
    chunks start ``chunk_size - chunk_overlap`` characters apart, so each
    chunk repeats the last ``chunk_overlap`` characters of the previous one.

    Invalid arguments are corrected (with a logged warning) rather than
    rejected:

    - ``chunk_size < 1`` is reset to 200.
    - ``chunk_overlap < 0`` is reset to 0.
    - ``chunk_overlap >= chunk_size`` is clamped to ``chunk_size - 1``.

    This is a generator, so the validation above runs when iteration starts,
    not when the function is called. An empty ``text`` yields nothing.
    """
    if chunk_size < 1:
        logging.warning("chunk_size was %s. Setting to 200", chunk_size)
        chunk_size = 200

    if chunk_overlap < 0:
        logging.warning("chunk_overlap was %s. Setting to 0", chunk_overlap)
        chunk_overlap = 0

    # The overlap must be strictly smaller than the chunk size: an equal (or
    # larger) overlap would produce a step of zero (or less), and range()
    # raises ValueError on a zero step.
    if chunk_overlap >= chunk_size:
        clamped_overlap = chunk_size - 1
        logging.warning(
            "chunk_overlap was %s. Setting to chunk_size - 1 (%s)",
            chunk_overlap,
            clamped_overlap,
        )
        chunk_overlap = clamped_overlap

    # Always >= 1 thanks to the clamping above.
    step_size = chunk_size - chunk_overlap

    for start in range(0, len(text), step_size):
        yield text[start : start + chunk_size]

tests/steamship_tests/agents/tools/test_fact_learner_tools.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ def test_fact_learner_agent_service(client: Steamship):
1414
agent.invoke("prompt", prompt="please remember that my name is Inigo Montoya")
1515
agent.invoke("prompt", prompt="please remember that I am skilled swordsman")
1616

17-
answer_blocks = agent.invoke("prompt", prompt="what is my name?")
17+
answer_blocks = agent.invoke("prompt", prompt="Is my name Inigo?")
1818
assert len(answer_blocks) == 1
19-
assert "Inigo Montoya" in Block(**answer_blocks[0]).text
19+
assert "yes" in Block(**answer_blocks[0]).text.lower()
2020

2121
answer_blocks = agent.invoke("prompt", prompt="what do I know how to do well?")
2222
assert len(answer_blocks) == 1
23-
assert "sword" in Block(**answer_blocks[0]).text
23+
assert "sword" in Block(**answer_blocks[0]).text.lower()

0 commit comments

Comments
 (0)