Skip to content

Commit aa1110c

Browse files
committed
Use accelerator pin memory
1 parent 27091f4 commit aa1110c

3 files changed

Lines changed: 14 additions & 4 deletions

File tree

deepnvme/ds_io/ds_io_sweep.sh

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ fi
2323
if [[ ${XFER} == "cpu" ]]; then
2424
xfer_opt=""
2525
elif [[ ${XFER} == "gpu" ]]; then
26-
xfer_opt="--gpu"
26+
xfer_opt="--gpu --use_accelerator_pin_memory"
2727
elif [[ ${XFER} == "gds" ]]; then
2828
xfer_opt="--gpu --use_gds"
2929
else

deepnvme/model_checkpoint/save_model_utils.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
import argparse
22
import os
3-
from transformers import AutoModelForCausalLM
3+
from transformers import AutoModelForCausalLM, AutoTokenizer
44
from transformers import T5ForConditionalGeneration
55
from torch_save_utils import PINNED_BUFFER_MB
66

@@ -23,10 +23,13 @@ def _get_hf_model(tag):
2323
model_name = HF_MODELS_DICT[tag]
2424
if tag == TINY_T5:
2525
model = T5ForConditionalGeneration.from_pretrained(model_name)
26+
2627
else:
2728
model = AutoModelForCausalLM.from_pretrained(model_name)
29+
tokenizer = AutoTokenizer.from_pretrained(model_name)
30+
2831

29-
return model, model_name, tag
32+
return model, tokenizer, model_name, tag
3033

3134
def get_model(model_tag):
3235
return _get_hf_model(model_tag)
@@ -108,6 +111,13 @@ def parse_arguments():
108111
action='store_true',
109112
help='Disable double buffering of i/o buffer.')
110113

114+
parser.add_argument('--safetensors',
115+
action='store_true',
116+
help='Use safetensors load/save.')
117+
118+
parser.add_argument('--regular_torch_save',
119+
action='store_true',
120+
help='Use vanilla torch.save.')
111121

112122
#parser.add_argument('--single_writer', action='store_true', help='Disable parallel rank writes of data parallel (replicated) state')
113123

deepnvme/model_checkpoint/torch_save_model.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,7 @@ def main():
5757
if not validate_arguments(args):
5858
quit()
5959
load_io_ops(args)
60-
model, model_name, ckpt_name = get_model(args.model)
60+
model, tokenizer, model_name, ckpt_name = get_model(args.model)
6161
if args.half:
6262
model = model.half()
6363
if args.gpu:

0 commit comments

Comments
 (0)