11import argparse
22import os
3- from transformers import AutoModelForCausalLM
3+ from transformers import AutoModelForCausalLM , AutoTokenizer
44from transformers import T5ForConditionalGeneration
55from torch_save_utils import PINNED_BUFFER_MB
66
@@ -23,10 +23,13 @@ def _get_hf_model(tag):
2323 model_name = HF_MODELS_DICT [tag ]
2424 if tag == TINY_T5 :
2525 model = T5ForConditionalGeneration .from_pretrained (model_name )
26+
2627 else :
2728 model = AutoModelForCausalLM .from_pretrained (model_name )
29+ tokenizer = AutoTokenizer .from_pretrained (model_name )
30+
2831
29- return model , model_name , tag
32+ return model , tokenizer , model_name , tag
3033
3134def get_model (model_tag ):
3235 return _get_hf_model (model_tag )
@@ -108,6 +111,13 @@ def parse_arguments():
108111 action = 'store_true' ,
109112 help = 'Disable double buffering of i/o buffer.' )
110113
114+ parser .add_argument ('--safetensors' ,
115+ action = 'store_true' ,
116+ help = 'Use safetensors load/save.' )
117+
118+ parser .add_argument ('--regular_torch_save' ,
119+ action = 'store_true' ,
120+ help = 'Use vanilla torch.save.' )
111121
112122 #parser.add_argument('--single_writer', action='store_true', help='Disable parallel rank writes of data parallel (replicated) state')
113123
0 commit comments