@@ -44,21 +44,42 @@ def get_special_tokens_dict(
 
     special_tokens_dict = {}
     if not tokenizer_name_or_path:
-        # TODO: understand if we need to hardcode these here or just use defaults in model
-        if isinstance(
-            tokenizer, (transformers.LlamaTokenizer, transformers.LlamaTokenizerFast)
-        ):
+        # TODO: understand if we need to hardcode these here or just use defaults in model
+        # Resolve tokenizer classes with getattr, since not every transformers
+        # release exposes both the slow and fast variants.
+        llama_classes = tuple(
+            cls for cls in [
+                getattr(transformers, "LlamaTokenizer", None),
+                getattr(transformers, "LlamaTokenizerFast", None),
+            ] if cls is not None
+        )
+        is_llama_tokenizer = (
+            (bool(llama_classes) and isinstance(tokenizer, llama_classes))
+            or "llama" in (getattr(tokenizer, "name_or_path", "") or "").lower()
+        )
+
+        gpt_neox_classes = tuple(
+            cls for cls in [
+                getattr(transformers, "GPTNeoXTokenizerFast", None),
+                getattr(transformers, "GPTNeoXTokenizer", None),
+            ] if cls is not None
+        )
+
+        if is_llama_tokenizer:
             special_tokens_dict["bos_token"] = "<s>"
             special_tokens_dict["eos_token"] = "</s>"
             special_tokens_dict["unk_token"] = "<unk>"
             special_tokens_dict["pad_token"] = "<pad>"
         elif isinstance(
-            tokenizer, (transformers.GPT2Tokenizer, transformers.GPTNeoXTokenizerFast)
+            tokenizer, (transformers.GPT2Tokenizer, *gpt_neox_classes)
         ):
             special_tokens_dict["pad_token"] = "<pad>"
 
     # Add special tokens only when a custom tokenizer is not passed
-    if tokenizer.pad_token is None:
+    if tokenizer.pad_token is None or "pad_token" in special_tokens_dict:
         logger.warning("PAD token set to default, missing in tokenizer")
         special_tokens_dict["pad_token"] = configs.DEFAULT_PAD_TOKEN
     if tokenizer.eos_token is None:
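Note: the hunk above replaces the hard isinstance check with a defensive lookup, since not every transformers release exports both Llama tokenizer classes. A minimal standalone sketch of the same pattern (the helper name looks_like_llama is illustrative, not part of this diff):

import transformers

# Collect whichever Llama tokenizer classes this transformers install exposes;
# getattr(..., None) avoids an AttributeError on releases that dropped one.
llama_classes = tuple(
    cls
    for cls in (
        getattr(transformers, "LlamaTokenizer", None),
        getattr(transformers, "LlamaTokenizerFast", None),
    )
    if cls is not None
)

def looks_like_llama(tokenizer) -> bool:
    # Prefer the class check; fall back to the checkpoint name so custom
    # wrappers around a Llama vocabulary are still caught.
    if llama_classes and isinstance(tokenizer, llama_classes):
        return True
    return "llama" in (getattr(tokenizer, "name_or_path", "") or "").lower()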
@@ -102,7 +123,8 @@ def tokenizer_and_embedding_resize(
         dict: Metadata on number of added tokens.
     """
     num_new_tokens = tokenizer.add_special_tokens(
-        special_tokens_dict=special_tokens_dict, replace_additional_special_tokens=False
+        special_tokens_dict=special_tokens_dict,
     )
     embedding_size = int(multiple_of * math.ceil(len(tokenizer) / multiple_of))
     num_new_tokens = num_new_tokens + embedding_size - len(tokenizer)
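For context, the multiple_of rounding above pads the resized embedding matrix up to the next multiple (commonly done for GPU-friendly sizes), and the padding rows are folded into num_new_tokens so they get initialized too. A quick check of the arithmetic with illustrative numbers:

import math

multiple_of = 8
vocab_len = 32005  # stand-in for len(tokenizer) after adding specials (illustrative)

embedding_size = int(multiple_of * math.ceil(vocab_len / multiple_of))
assert embedding_size == 32008  # rounded up to the next multiple of 8

# The 3 padding rows count as "new" tokens, so their embeddings are
# initialized alongside the genuinely added special tokens.
assert embedding_size - vocab_len == 3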
@@ -119,8 +141,11 @@ def tokenizer_and_embedding_resize(
         model.set_input_embeddings(resized_input_embeddings)
 
         # Resize vocab size when embeddings updated for Mllama models
-        if model.language_model.vocab_size != embedding_size:
-            model.language_model.vocab_size = embedding_size
+        if model.model.vocab_size != embedding_size:
+            model.model.vocab_size = embedding_size
     else:
         model.resize_token_embeddings(embedding_size)
 
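The last hunk moves the vocab-size sync from model.language_model to model.model, presumably tracking a refactor of the Mllama class layout in transformers. If both layouts needed to be supported at once, a version-agnostic sketch could look like this (illustrative, not what the diff does):

# Update vocab_size wherever the attribute actually lives, since the nesting
# differs across transformers versions.
for holder in (getattr(model, "model", None), getattr(model, "language_model", None)):
    if holder is not None and getattr(holder, "vocab_size", None) is not None:
        if holder.vocab_size != embedding_size:
            holder.vocab_size = embedding_size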