Skip to content

Commit 7c95e3c

Browse files
new: use a dataclass to store the type of tasks
Signed-off-by: thiswillbeyourgithub <26625900+thiswillbeyourgithub@users.noreply.github.com>
1 parent 5257c5a commit 7c95e3c

File tree

9 files changed

+117
-52
lines changed

9 files changed

+117
-52
lines changed

wdoc/__main__.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
from wdoc.utils import logger as importedlogger # make sure to setup the logs first
2121
from wdoc.wdoc import wdoc
2222
from wdoc.utils.env import is_out_piped
23-
from wdoc.utils.misc import get_piped_input, tasks_list
23+
from wdoc.utils.misc import get_piped_input
24+
from wdoc.utils.tasks.types import __valid_tasks__
2425
from wdoc.utils.batch_file_loader import infer_filetype, NoInferrableFiletype
2526
from typing import Tuple, List, Dict, Any
2627
import io
@@ -278,13 +279,13 @@ def cli_launcher() -> None:
278279

279280
# turn "wdoc query" into "wdoc --task=query", same for the other tasks
280281
if "task" not in kwargs:
281-
matching_tasks = [t for t in args if t in tasks_list]
282+
matching_tasks = [t for t in args if t in __valid_tasks__]
282283
assert (
283284
len(matching_tasks) != 0
284-
), f"Found no task in the args: '{args}', wdoc needs one of {tasks_list}"
285+
), f"Found no task in the args: '{args}', wdoc needs one of {__valid_tasks__}"
285286
assert (
286287
len(matching_tasks) == 1
287-
), f"Found multiple potential tasks in args: '{args}', wdoc needs one of {tasks_list}"
288+
), f"Found multiple potential tasks in args: '{args}', wdoc needs one of {__valid_tasks__}"
288289
task = matching_tasks[0]
289290
logger.debug(f"Moving task '{task}' from args to kwargs")
290291
args.remove(task)

wdoc/utils/batch_file_loader.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from loguru import logger
2727

2828
from wdoc.utils.env import env, is_out_piped
29+
from wdoc.utils.tasks.types import wdocTask
2930
from wdoc.utils.loaders import (
3031
load_one_doc,
3132
markdownlink_regex,
@@ -122,7 +123,7 @@ def infer_filetype(path: str) -> str:
122123
def batch_load_doc(
123124
llm_name: ModelName,
124125
filetype: str,
125-
task: str,
126+
task: wdocTask,
126127
backend: str,
127128
n_jobs: int,
128129
**cli_kwargs,
@@ -297,7 +298,7 @@ def batch_load_doc(
297298
"exclude",
298299
], "Include or exclude arguments should be reomved at this point"
299300

300-
if "summar" not in task:
301+
if not task.summarize:
301302
# shuffle the list of files to load to make
302303
# the hashing progress bar more representative
303304
to_load = sorted(to_load, key=lambda x: random.random())
@@ -320,7 +321,7 @@ def batch_load_doc(
320321
for i, h in enumerate(doc_hashes):
321322
to_load[i]["file_hash"] = doc_hashes[i]
322323

323-
if "summar" not in task:
324+
if not task.summarize:
324325
# shuffle the list of files again to be random but deterministic:
325326
# keeping only the digits of each hash, then multiplying by the
326327
# index of the filetype by size. This makes sure the doc dicts are
@@ -529,7 +530,8 @@ def batch_load_doc(
529530

530531
# smart deduplication before embedding:
531532
# find the document with the same content_hash, merge their metadata and keep only one
532-
if "summar" not in task and len(docs) > 1:
533+
534+
if not task.summarize and len(docs) > 1:
533535
logger.debug("Deduplicating...")
534536
logger.debug("Getting all hash")
535537
content_hash = [d.metadata["content_hash"] for d in docs]

wdoc/utils/loaders/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from wdoc.utils.env import env
2222
from wdoc.utils.errors import MissingDocdictArguments, TimeoutPdfLoaderError
2323
from wdoc.utils.loaders.shared import get_url_title
24+
from wdoc.utils.tasks.types import wdocTask
2425
from wdoc.utils.misc import (
2526
ModelName,
2627
average_word_length,
@@ -116,7 +117,7 @@ def wrapper(*args, **kwargs) -> Union[List[Document], str]:
116117

117118
@wrapper_load_one_doc
118119
def load_one_doc(
119-
task: str,
120+
task: wdocTask,
120121
llm_name: ModelName,
121122
temp_dir: Path,
122123
filetype: str,

wdoc/utils/misc.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737

3838
from wdoc.utils.env import env, is_input_piped, pytest_ongoing
3939
from wdoc.utils.errors import UnexpectedDocDictArgument
40+
from wdoc.utils.tasks.types import wdocTask
4041

4142
import lazy_import
4243

@@ -140,9 +141,6 @@ def language_detector(text: str) -> None:
140141
max_token = 10_000_000
141142
min_lang_prob = 0.50
142143

143-
# list of available tasks
144-
tasks_list = ["query", "summarize", "parse", "search", "summarize_then_query"]
145-
146144
printed_unexpected_api_keys = [False] # to print it only once
147145

148146
# loader specific arguments
@@ -713,7 +711,7 @@ def get_tkn_length(
713711

714712

715713
def get_splitter(
716-
task: str,
714+
task: wdocTask,
717715
modelname: ModelName = DEFAULT_SPLITTER_MODELNAME,
718716
) -> "TextSplitter":
719717
"we don't use the same text splitter depending on the task"
@@ -724,7 +722,7 @@ def get_splitter(
724722
return text_splitters[task][modelname.original]
725723

726724
# if task is parse but we let the model as testing: assume we want a single super large document with no splitting
727-
if task == "parse" and modelname.original == "cliparser/cliparser":
725+
if task.parse and modelname.original == "cliparser/cliparser":
728726
return RecursiveCharacterTextSplitter(
729727
separators=recur_separator,
730728
chunk_size=1e7,
@@ -749,7 +747,7 @@ def get_splitter(
749747
)
750748

751749
# Cap context sizes
752-
if task in ["query", "search"] and max_tokens > env.WDOC_MAX_EMBED_CONTEXT:
750+
if (task.query or task.search) and max_tokens > env.WDOC_MAX_EMBED_CONTEXT:
753751
logger.warning(
754752
f"Capping max_tokens for model {modelname} to WDOC_MAX_EMBED_CONTEXT ({env.WDOC_MAX_EMBED_CONTEXT} instead of {max_tokens}) because in query mode and we can only guess the context size of the embedding model."
755753
)
@@ -762,27 +760,20 @@ def get_splitter(
762760

763761
model_tkn_length = partial(get_tkn_length, modelname=modelname.original)
764762

765-
if task in ["query", "search", "parse"]:
763+
if task.query or task.search or task.parse:
766764
text_splitter = RecursiveCharacterTextSplitter(
767765
separators=recur_separator,
768766
chunk_size=int(3 / 4 * max_tokens), # default 4000
769767
chunk_overlap=500, # default 200
770768
length_function=model_tkn_length,
771769
)
772-
elif task in ["summarize_then_query", "summarize"]:
770+
elif task.summarize:
773771
text_splitter = RecursiveCharacterTextSplitter(
774772
separators=recur_separator,
775773
chunk_size=int(1 / 2 * max_tokens),
776774
chunk_overlap=500,
777775
length_function=model_tkn_length,
778776
)
779-
elif task == "recursive_summary":
780-
text_splitter = RecursiveCharacterTextSplitter(
781-
separators=recur_separator,
782-
chunk_size=int(1 / 4 * max_tokens),
783-
chunk_overlap=300,
784-
length_function=model_tkn_length,
785-
)
786777
else:
787778
raise Exception(task)
788779

wdoc/utils/retrievers.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from langchain_core.embeddings import Embeddings
1212

1313
from wdoc.utils.env import env
14+
from wdoc.utils.tasks.types import wdocTask
1415
from wdoc.utils.misc import cache_dir, get_splitter
1516
from wdoc.utils.prompts import multiquery_parser, prompts
1617
from wdoc.utils.customs.compressed_embeddings_cacher import LocalFileStore
@@ -42,7 +43,7 @@ def create_multiquery_retriever(
4243

4344

4445
def create_parent_retriever(
45-
task: str,
46+
task: wdocTask,
4647
loaded_embeddings: Any,
4748
loaded_docs: List[Document],
4849
top_k: int,
@@ -85,7 +86,7 @@ def create_retrievers(
8586
llm,
8687
top_k: int,
8788
relevancy: float,
88-
task: str,
89+
task: wdocTask,
8990
loaded_docs: Optional[List[Document]],
9091
) -> BaseRetriever:
9192
"""Create and return list of retrievers based on query_retrievers setting."""

wdoc/utils/tasks/parse.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from wdoc.utils.batch_file_loader import batch_load_doc
1313
from wdoc.utils.logger import debug_exceptions, set_parse_doc_help_md_as_docstring
1414
from wdoc.utils.misc import DocDict, ModelName
15+
from wdoc.utils.tasks.types import wdocTask
1516

1617

1718
@set_parse_doc_help_md_as_docstring
@@ -77,7 +78,7 @@ def parse_doc(
7778
cli_kwargs[k] = v
7879

7980
out = batch_load_doc(
80-
task="parse",
81+
task=wdocTask("parse"),
8182
filetype=filetype,
8283
**cli_kwargs,
8384
**docdict_kwargs,

wdoc/utils/tasks/summarize.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,19 @@
1313
from loguru import logger
1414
import copy
1515
from dataclasses import dataclass, asdict
16+
from functools import partial
17+
from langchain.text_splitter import RecursiveCharacterTextSplitter
1618

1719
from wdoc.utils.logger import (
1820
md_printer,
1921
)
2022
from wdoc.utils.misc import (
2123
# debug_chain,
2224
ModelName,
25+
get_model_max_tokens,
2326
average_word_length,
2427
check_docs_tkn_length,
25-
get_splitter,
28+
recur_separator,
2629
get_tkn_length,
2730
log_and_time_fn,
2831
thinking_answer_parser,
@@ -257,6 +260,23 @@ def summarize_documents(
257260
verbose=llm_verbosity,
258261
)
259262

263+
model_tkn_length = partial(get_tkn_length, modelname=model.original)
264+
if model.is_testing():
265+
max_tokens = 4096
266+
else:
267+
max_tokens = get_model_max_tokens(model)
268+
if max_tokens > env.WDOC_MAX_CHUNK_SIZE:
269+
logger.debug(
270+
f"Capping max_tokens for model {model.original} to the WDOC_MAX_CHUNK_SIZE value ({env.WDOC_MAX_CHUNK_SIZE} instead of {max_tokens})."
271+
)
272+
max_tokens = min(max_tokens, env.WDOC_MAX_CHUNK_SIZE)
273+
splitter = RecursiveCharacterTextSplitter(
274+
separators=recur_separator,
275+
chunk_size=int(1 / 4 * max_tokens),
276+
chunk_overlap=300,
277+
length_function=model_tkn_length,
278+
)
279+
260280
# get reading length of the summary
261281
real_text = "".join([letter for letter in list(summary) if letter.isalpha()])
262282
sum_reading_length = len(real_text) / average_word_length / wpm
@@ -285,10 +305,6 @@ def summarize_documents(
285305
assert "- Chunk " not in summary_text, "Found chunk marker"
286306
assert "- BEFORE RECURSION # " not in summary_text, "Found recursion block"
287307

288-
splitter = get_splitter(
289-
"recursive_summary",
290-
modelname=model,
291-
)
292308
summary_docs = [Document(page_content=summary_text)]
293309
summary_docs = splitter.transform_documents(summary_docs)
294310
assert summary_docs != relevant_docs

wdoc/utils/tasks/types.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from typing import List
2+
from dataclasses import dataclass, field
3+
4+
__valid_tasks__: List[str] = [
5+
"query",
6+
"summarize",
7+
"parse",
8+
"search",
9+
"summarize_then_query",
10+
]
11+
12+
13+
@dataclass
14+
class wdocTask:
15+
original: str
16+
query: bool = field(init=False)
17+
summarize: bool = field(init=False)
18+
parse: bool = field(init=False)
19+
search: bool = field(init=False)
20+
21+
def __post_init__(self):
22+
# default values
23+
self.query = False
24+
self.summarize = False
25+
self.parse = False
26+
self.search = False
27+
28+
# checks
29+
if isinstance(self.original, wdocTask):
30+
self.original = self.original.original
31+
self.original = self.original.replace("summary", "summarize")
32+
assert (
33+
self.original in __valid_tasks__
34+
), f"Received task '{self.original}' is not part of expected tasks: '{__valid_tasks__}'"
35+
36+
# set the actual properties
37+
if "query" in self.original:
38+
self.query = True
39+
if "summarize" in self.original:
40+
self.summarize = True
41+
if "search" in self.original:
42+
self.search = True
43+
if "parse" in self.original:
44+
self.parse = True
45+
46+
def __hash__(self):
47+
"necessary for memoizing"
48+
return self.original.__hash__()
49+
50+
def __str__(self) -> str:
51+
return self.original

0 commit comments

Comments
 (0)