Commit 4a1e2c7

Accommodate the Mllama model.language_model issue and handle the GPTNeoXTokenizer and LlamaTokenizer variants
1 parent 2afffd0 commit 4a1e2c7

1 file changed

Lines changed: 34 additions & 9 deletions

File tree

tuning/data/tokenizer_utils.py

@@ -44,21 +44,42 @@ def get_special_tokens_dict(
 
     special_tokens_dict = {}
     if not tokenizer_name_or_path:
-        # TODO: understand if we need to hardcode these here or just use defaults in model
-        if isinstance(
-            tokenizer, (transformers.LlamaTokenizer, transformers.LlamaTokenizerFast)
-        ):
+        # # TODO: understand if we need to hardcode these here or just use defaults in model
+        # if isinstance(
+        #     tokenizer, (transformers.LlamaTokenizer, transformers.LlamaTokenizerFast)
+        # ):
+        llama_classes = tuple(
+            cls for cls in [
+                getattr(transformers, "LlamaTokenizer", None),
+                getattr(transformers, "LlamaTokenizerFast", None),
+            ] if cls is not None
+        )
+        is_llama_tokenizer = (
+            (bool(llama_classes) and isinstance(tokenizer, llama_classes))
+            or "llama" in (getattr(tokenizer, "name_or_path", "") or "").lower()
+        )
+
+        gpt_neox_classes = tuple(
+            cls for cls in [
+                getattr(transformers, "GPTNeoXTokenizerFast", None),
+                getattr(transformers, "GPTNeoXTokenizer", None),
+            ] if cls is not None
+        )
+
+        if is_llama_tokenizer:
             special_tokens_dict["bos_token"] = "<s>"
             special_tokens_dict["eos_token"] = "</s>"
             special_tokens_dict["unk_token"] = "<unk>"
             special_tokens_dict["pad_token"] = "<pad>"
         elif isinstance(
-            tokenizer, (transformers.GPT2Tokenizer, transformers.GPTNeoXTokenizerFast)
+            # tokenizer, (transformers.GPT2Tokenizer, transformers.GPTNeoXTokenizerFast)
+            tokenizer, (transformers.GPT2Tokenizer, *gpt_neox_classes)
         ):
             special_tokens_dict["pad_token"] = "<pad>"
 
         # Add special tokens only when a custom tokenizer is not passed
-        if tokenizer.pad_token is None:
+        # if tokenizer.pad_token is None:
+        if tokenizer.pad_token is None or "pad_token" in special_tokens_dict:
             logger.warning("PAD token set to default, missing in tokenizer")
             special_tokens_dict["pad_token"] = configs.DEFAULT_PAD_TOKEN
         if tokenizer.eos_token is None:
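The getattr-based lookup in this hunk keeps the dependency on specific tokenizer classes soft: a transformers release that lacks one of them contributes nothing to the tuple instead of raising an AttributeError at import or call time. A minimal standalone sketch of the pattern, assuming only that transformers is importable (the helper name resolve_classes is hypothetical, not part of tokenizer_utils.py):

import transformers

def resolve_classes(*names):
    # Return a tuple of the named transformers attributes that actually exist.
    return tuple(
        cls
        for cls in (getattr(transformers, n, None) for n in names)
        if cls is not None
    )

llama_classes = resolve_classes("LlamaTokenizer", "LlamaTokenizerFast")
gpt_neox_classes = resolve_classes("GPTNeoXTokenizerFast", "GPTNeoXTokenizer")

# isinstance() against an empty tuple is simply False, so a missing class
# degrades to "not a Llama tokenizer" rather than an exception; the
# name_or_path substring check in the diff then serves as a fallback signal.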
@@ -102,7 +123,8 @@ def tokenizer_and_embedding_resize(
         dict: Metadata on number of added tokens.
     """
     num_new_tokens = tokenizer.add_special_tokens(
-        special_tokens_dict=special_tokens_dict, replace_additional_special_tokens=False
+        special_tokens_dict=special_tokens_dict,
+        # replace_additional_special_tokens=False
     )
     embedding_size = int(multiple_of * math.ceil(len(tokenizer) / multiple_of))
     num_new_tokens = num_new_tokens + embedding_size - len(tokenizer)
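Commenting out replace_additional_special_tokens=False presumably sidesteps transformers versions whose add_special_tokens signature predates that keyword, at the cost of reverting to the default behaviour where the keyword does exist. If the intent were to keep the old behaviour when available, one version-tolerant alternative is to probe the signature at call time; the sketch below is not the commit's code, and add_special_tokens_compat is a hypothetical name:

def add_special_tokens_compat(tokenizer, special_tokens_dict):
    # Prefer replace_additional_special_tokens=False where the installed
    # transformers version accepts it; fall back to the bare call otherwise.
    try:
        return tokenizer.add_special_tokens(
            special_tokens_dict, replace_additional_special_tokens=False
        )
    except TypeError:  # keyword not in this version's signature
        return tokenizer.add_special_tokens(special_tokens_dict)

The embedding_size line that follows rounds len(tokenizer) up to a multiple of multiple_of, a common trick for keeping the resized embedding dimension GPU-friendly; the padding rows are then folded into num_new_tokens.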
@@ -119,8 +141,11 @@ def tokenizer_and_embedding_resize(
         model.set_input_embeddings(resized_input_embeddings)
 
         # Resize vocab size when embeddings updated for Mllama models
-        if model.language_model.vocab_size != embedding_size:
-            model.language_model.vocab_size = embedding_size
+        # if model.language_model.vocab_size != embedding_size:
+        #     model.language_model.vocab_size = embedding_size
+        if model.model.vocab_size != embedding_size:
+            model.model.vocab_size = embedding_size
+
     else:
         model.resize_token_embeddings(embedding_size)
 
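The attribute-path swap (model.language_model to model.model) presumably tracks a relayout of the Mllama wrapper across transformers releases; the commented-out lines keep the old path visible. If both layouts had to be supported simultaneously, a defensive variant could probe the candidate paths, roughly as in this sketch (sync_vocab_size is a hypothetical helper, not in the file):

def sync_vocab_size(model, embedding_size):
    # Update vocab_size on whichever sub-module exposes it; the attribute
    # lives under different names depending on the model wrapper layout.
    for attr in ("model", "language_model"):
        sub = getattr(model, attr, None)
        if sub is not None and hasattr(sub, "vocab_size"):
            if sub.vocab_size != embedding_size:
                sub.vocab_size = embedding_size
            return True
    return False  # no candidate sub-module found; leave the model untouched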