Skip to content

Commit 610f21d

Browse files
fix: sieve of documents was not getting updated if top_k changed
1 parent 4f8494f commit 610f21d

2 files changed

Lines changed: 10 additions & 7 deletions

File tree

WDoc/WDoc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1426,7 +1426,7 @@ def retrieve_documents(inputs):
14261426
}
14271427
rag_chain = (
14281428
retrieve_documents
1429-
| sieve_documents(top_k=self.top_k, max_top_k=self.max_top_k)
1429+
| sieve_documents(instance=self)
14301430
| refilter_documents
14311431
)
14321432
tried_top_k = []
@@ -1576,7 +1576,7 @@ def retrieve_documents(inputs):
15761576

15771577
rag_chain = (
15781578
retrieve_documents
1579-
| sieve_documents(top_k=self.top_k, max_top_k=self.max_top_k)
1579+
| sieve_documents(instance=self)
15801580
| pbar_chain(
15811581
llm=self.eval_llm,
15821582
len_func="len(inputs['unfiltered_docs'])",

WDoc/utils/tasks/query.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,22 +48,25 @@ def check_intermediate_answer(ans: str) -> bool:
4848

4949

5050
@optional_typecheck
51-
def sieve_documents(top_k: int, max_top_k: int) -> RunnableLambda:
51+
def sieve_documents(instance) -> RunnableLambda:
5252
"""cap the number of retrieved documents as if multiple retrievers are used
5353
we can end up with a lot more document!
5454
"""
55-
assert max_top_k >= top_k
5655
@chain
5756
@optional_typecheck
5857
def _sieve(inputs: dict) -> dict:
5958
assert "question_to_answer" in inputs, inputs.keys()
6059
assert "unfiltered_docs" in inputs, inputs.keys()
61-
if len(inputs) > top_k:
60+
# we have to pass an instance otherwise we can't know if the top_k got updated
61+
assert hasattr(instance, "top_k")
62+
assert hasattr(instance, "max_top_k")
63+
assert instance.max_top_k >= instance.top_k
64+
if len(inputs) > instance.top_k:
6265
red(
6366
"Number of documents found via embeddings was "
64-
f"'{inputs['unfiltered_docs']}' which is > top_k ({top_k}) "
67+
f"'{inputs['unfiltered_docs']}' which is > top_k ({instance.top_k}) "
6568
"so we crop")
66-
inputs["unfiltered_docs"] = inputs["unfiltered_docs"][:top_k]
69+
inputs["unfiltered_docs"] = inputs["unfiltered_docs"][:instance.top_k]
6770
return inputs
6871
return _sieve
6972

0 commit comments

Comments
 (0)