@@ -7,12 +7,12 @@ class Chunking:
 
     CHUNKER_PARAMS = {
         'token': ['chunk_size', 'chunk_overlap', 'tokenizer'],
-        'word': ['chunk_size', 'chunk_overlap', 'tokenizer'],
-        'sentence': ['chunk_size', 'chunk_overlap', 'tokenizer'],
-        'semantic': ['chunk_size', 'embedding_model', 'tokenizer'],
-        'sdpm': ['chunk_size', 'embedding_model', 'tokenizer'],
-        'late': ['chunk_size', 'embedding_model', 'tokenizer'],
-        'recursive': ['chunk_size', 'tokenizer']
+        'word': ['chunk_size', 'chunk_overlap', 'tokenizer_or_token_counter'],
+        'sentence': ['chunk_size', 'chunk_overlap', 'tokenizer_or_token_counter'],
+        'recursive': ['chunk_size', 'tokenizer_or_token_counter'],
+        'semantic': ['chunk_size', 'embedding_model'],
+        'sdpm': ['chunk_size', 'embedding_model'],
+        'late': ['chunk_size', 'embedding_model'],
     }
 
     @cached_property
1818 @cached_property
@@ -48,7 +48,7 @@ def __init__(
         chunker_type: str = 'token',
         chunk_size: int = 512,
         chunk_overlap: int = 128,
-        tokenizer: str = "gpt2",
+        tokenizer_or_token_counter: str = "gpt2",
         embedding_model: Optional[Union[str, Any]] = None,
         **kwargs
     ):
@@ -62,7 +62,7 @@ def __init__(
         self.chunker_type = chunker_type
         self.chunk_size = chunk_size
         self.chunk_overlap = chunk_overlap
-        self.tokenizer = tokenizer
+        self.tokenizer_or_token_counter = tokenizer_or_token_counter
         self._embedding_model = embedding_model
         self.kwargs = kwargs
 
@@ -89,11 +89,10 @@ def _get_chunker_params(self) -> Dict[str, Any]:
         if 'chunk_overlap' in allowed_params:
             params['chunk_overlap'] = self.chunk_overlap
 
-        if 'tokenizer' in allowed_params:
-            if self.chunker_type in ['semantic', 'sdpm', 'late']:
-                params['tokenizer'] = self.embedding_model.get_tokenizer_or_token_counter()
-            else:
-                params['tokenizer'] = self.tokenizer
+        if 'tokenizer_or_token_counter' in allowed_params:
+            params['tokenizer_or_token_counter'] = self.tokenizer_or_token_counter
+        elif 'tokenizer' in allowed_params:
+            params['tokenizer'] = self.tokenizer_or_token_counter
 
         if 'embedding_model' in allowed_params:
             params['embedding_model'] = self.embedding_model
0 commit comments