@@ -510,6 +510,55 @@ def encode(
510510 )
511511
512512 return encoded_inputs ["input_ids" ][0 ]
513+
def batch_encode(
    self,
    text: Union[TextInput, PreTokenizedInput],
    text_pair: Optional[Union[TextInput, PreTokenizedInput]] = None,
    add_special_tokens: bool = True,
    padding: Union[bool, str, PaddingStrategy] = False,
    truncation: Optional[Union[bool, str, TruncationStrategy]] = None,
    max_length: Optional[int] = None,
    stride: int = 0,
    padding_side: Optional[str] = None,
    return_tensors: Optional[Union[str, TensorType]] = None,
    **kwargs,
) -> List[List[int]]:
    """
    Encode a batch of sequences and return only their token ids.

    Thin wrapper around `self._call_one`: every argument is forwarded unchanged, `is_split_into_words` is
    forced to `False`, and the `"input_ids"` entry of the result is returned as is. Unlike `encode`, the
    result is not indexed with `[0]`, so the ids of every sequence are returned.

    Args:
        text (`List[str]` or `List[List[str]]`):
            The sequences to encode, one entry per batch element. NOTE(review): whether the input is
            treated as a batch is decided inside `_call_one`, not here -- a single `str` is presumably
            encoded as one sequence and would come back as a flat list; confirm against `_call_one`.
        text_pair (`List[str]` or `List[List[str]]`, *optional*):
            Optional second sequences, forwarded positionally to `_call_one`.
        add_special_tokens (`bool`, *optional*, defaults to `True`):
            Forwarded to `_call_one`.
        padding (`bool`, `str` or `PaddingStrategy`, *optional*, defaults to `False`):
            Padding strategy, forwarded to `_call_one`.
        truncation (`bool`, `str` or `TruncationStrategy`, *optional*):
            Truncation strategy, forwarded to `_call_one`. `None` leaves the choice to `_call_one`.
        max_length (`int`, *optional*):
            Maximum length used by the padding/truncation strategies, forwarded to `_call_one`.
        stride (`int`, *optional*, defaults to `0`):
            Forwarded to `_call_one`.
        padding_side (`str`, *optional*):
            Side on which padding is applied. This is a string (`"right"` or `"left"` in the Transformers
            API), not a boolean. `None` leaves the choice to `_call_one`.
        return_tensors (`str` or `TensorType`, *optional*):
            Forwarded to `_call_one`; selects the container type of the returned ids.
        **kwargs:
            Any additional keyword arguments, forwarded to `_call_one`.

    Returns:
        `List[List[int]]`: The value stored under `"input_ids"` by `_call_one`. With `return_tensors` set
        this is presumably a framework tensor or array rather than nested lists -- confirm against
        `_call_one`.
    """
    encoded_inputs = self._call_one(
        text,
        text_pair,
        add_special_tokens=add_special_tokens,
        padding=padding,
        truncation=truncation,
        max_length=max_length,
        stride=stride,
        # Inputs are never treated as pre-split words by this entry point.
        is_split_into_words=False,
        padding_side=padding_side,
        return_tensors=return_tensors,
        **kwargs,
    )

    # Keep the whole batch: no `[0]` indexing here, in contrast to `encode`.
    return encoded_inputs["input_ids"]
560+
561+
513562
514563 def _init_set (self , key , current_value , value_if_key_not_exist ):
515564 if current_value != None :
@@ -835,8 +884,8 @@ def batch_decode(
835884 def train_new_from_iterator (
836885 self ,
837886 text_iterator ,
838- vocab_size ,
839- special_tokens_map = None ,
887+ vocab_size : int ,
888+ special_tokens_map : dict [ str , str ] | None = None ,
840889 ** kwargs ,
841890 ):
842891 """
@@ -894,11 +943,11 @@ def train_new_from_iterator(
894943
895944 def train_new_from_counts (
896945 self ,
897- word_counts ,
898- vocab_size ,
899- max_token_length = None ,
900- min_word_count = None ,
901- special_tokens_map = None ,
946+ word_counts : dict [ str , int ] ,
947+ vocab_size : int ,
948+ max_token_length : int = 20 ,
949+ min_word_count : int = 0 ,
950+ special_tokens_map : dict [ str , str ] | None = None ,
902951 ** kwargs ,
903952 ):
904953 """
0 commit comments