3737
3838from wdoc .utils .env import env , is_input_piped , pytest_ongoing
3939from wdoc .utils .errors import UnexpectedDocDictArgument
40+ from wdoc .utils .tasks .types import wdocTask
4041
4142import lazy_import
4243
@@ -140,9 +141,6 @@ def language_detector(text: str) -> None:
140141max_token = 10_000_000
141142min_lang_prob = 0.50
142143
143- # list of available tasks
144- tasks_list = ["query" , "summarize" , "parse" , "search" , "summarize_then_query" ]
145-
146144printed_unexpected_api_keys = [False ] # to print it only once
147145
148146# loader specific arguments
@@ -713,7 +711,7 @@ def get_tkn_length(
713711
714712
715713def get_splitter (
716- task : str ,
714+ task : wdocTask ,
717715 modelname : ModelName = DEFAULT_SPLITTER_MODELNAME ,
718716) -> "TextSplitter" :
719717 "we don't use the same text splitter depending on the task"
@@ -724,7 +722,7 @@ def get_splitter(
724722 return text_splitters [task ][modelname .original ]
725723
726724 # if task is parse but we let the model as testing: assume we want a single super large document with no splitting
727- if task == " parse" and modelname .original == "cliparser/cliparser" :
725+ if task . parse and modelname .original == "cliparser/cliparser" :
728726 return RecursiveCharacterTextSplitter (
729727 separators = recur_separator ,
730728 chunk_size = 1e7 ,
@@ -749,7 +747,7 @@ def get_splitter(
749747 )
750748
751749 # Cap context sizes
752- if task in [ " query" , " search" ] and max_tokens > env .WDOC_MAX_EMBED_CONTEXT :
750+ if ( task . query or task . search ) and max_tokens > env .WDOC_MAX_EMBED_CONTEXT :
753751 logger .warning (
754752 f"Capping max_tokens for model { modelname } to WDOC_MAX_EMBED_CONTEXT ({ env .WDOC_MAX_EMBED_CONTEXT } instead of { max_tokens } ) because in query mode and we can only guess the context size of the embedding model."
755753 )
@@ -762,27 +760,20 @@ def get_splitter(
762760
763761 model_tkn_length = partial (get_tkn_length , modelname = modelname .original )
764762
765- if task in [ " query" , " search" , " parse" ] :
763+ if task . query or task . search or task . parse :
766764 text_splitter = RecursiveCharacterTextSplitter (
767765 separators = recur_separator ,
768766 chunk_size = int (3 / 4 * max_tokens ), # default 4000
769767 chunk_overlap = 500 , # default 200
770768 length_function = model_tkn_length ,
771769 )
772- elif task in [ "summarize_then_query" , " summarize" ] :
770+ elif task . summarize :
773771 text_splitter = RecursiveCharacterTextSplitter (
774772 separators = recur_separator ,
775773 chunk_size = int (1 / 2 * max_tokens ),
776774 chunk_overlap = 500 ,
777775 length_function = model_tkn_length ,
778776 )
779- elif task == "recursive_summary" :
780- text_splitter = RecursiveCharacterTextSplitter (
781- separators = recur_separator ,
782- chunk_size = int (1 / 4 * max_tokens ),
783- chunk_overlap = 300 ,
784- length_function = model_tkn_length ,
785- )
786777 else :
787778 raise Exception (task )
788779
0 commit comments