@@ -38,10 +38,17 @@ export MODEL_BUCKET=$4
3838# Point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you created; this bucket will store all the files generated by MaxText during a run, specifically the unscanned checkpoint.
3939export BASE_OUTPUT_DIRECTORY=$5
4040
41- export LORA_LOCAL_PATH=$6
41+ export HUGGING_FACE_CHECKPOINT=$6  # 6th positional argument; forwarded as --huggingface-checkpoint to the conversion script
42+
43+ export LORA_INPUT_ADAPTERS_PATH=$7  # 7th positional argument; a non-empty value selects the LoRA conversion branch
4244
4345export BUCKET_LOCATION=US
4446
47+ if [[ -z "${HUGGING_FACE_CHECKPOINT}" ]]; then  # required: every checkpoint conversion call passes it on
48+ echo " HUGGING_FACE_CHECKPOINT is required." >&2
49+ exit 1
50+ fi
51+
4552# Create three GCS buckets for the demo.
4653gcloud storage buckets create ${MODEL_BUCKET} --location=${BUCKET_LOCATION} || true
4754gcloud storage buckets create ${BASE_OUTPUT_DIRECTORY} --location=${BUCKET_LOCATION} || true
5966 # llama_or_mistral_ckpt.py requires a local path, so we need to copy the checkpoint from CHKPT_BUCKET to a local directory.
6067 tmp_ckpt_path=" /tmp/"
6168 gcloud storage cp -r ${CHKPT_BUCKET} ${tmp_ckpt_path}
69+
6270 path_parts=(${CHKPT_BUCKET// \/ / } )
6371 directory_substring=${path_parts[-1]}
6472 CONVERT_CKPT_SCRIPT=" llama_or_mistral_ckpt.py"
65- if [[ -x " ${LORA_LOCAL_PATH} " ]]; then
73+
74+ if [[ ! -z " ${LORA_INPUT_ADAPTERS_PATH} " ]]; then
75+ lora_local_path=" /tmp/"
76+ # Resolve lora_local_path: stage gs:// adapters under ${tmp_ckpt_path}; otherwise use the given path as-is.
77+ if [[ "${LORA_INPUT_ADAPTERS_PATH}" =~ ^gs:// ]]; then
78+ IFS=/ read -r -a path_parts <<< "${LORA_INPUT_ADAPTERS_PATH%/}"
79+ lora_dir_substring=${path_parts[-1]}
80+ # Local directory named after the last path component; presumably where the recursive copy below lands (confirm).
81+ lora_local_path="${tmp_ckpt_path}${lora_dir_substring}"
82+ if [[ ! -d "${lora_local_path}" ]]; then
83+ mkdir -p "${lora_local_path}"
84+ fi
85+ gcloud storage cp -r "${LORA_INPUT_ADAPTERS_PATH}" "${tmp_ckpt_path}"
86+ else
87+ lora_local_path=${LORA_INPUT_ADAPTERS_PATH}
88+ fi
89+
6690 JAX_PLATFORMS=cpu python MaxText/${CONVERT_CKPT_SCRIPT} \
6791 --base-model-path ${tmp_ckpt_path}${directory_substring} \
6892 --maxtext-model-path ${MODEL_BUCKET} /${MODEL} /${MODEL_VARIATION} /${idx} \
6993 --model-size ${MODEL_NAME} \
70- --lora-config- path ${LORA_LOCAL_PATH} /adapter_config.json \
71- --lora-model-path ${LORA_LOCAL_PATH} /adapter_model.bin
94+ --lora-input-adapters- path ${lora_local_path} \
95+ --huggingface-checkpoint ${HUGGING_FACE_CHECKPOINT}
7296 else
7397 JAX_PLATFORMS=cpu python MaxText/${CONVERT_CKPT_SCRIPT} \
7498 --base-model-path ${tmp_ckpt_path}${directory_substring} \
7599 --maxtext-model-path ${MODEL_BUCKET} /${MODEL} /${MODEL_VARIATION} /${idx} \
76- --model-size ${MODEL_NAME}
100+ --model-size ${MODEL_NAME} \
101+ --huggingface-checkpoint ${HUGGING_FACE_CHECKPOINT}
77102 fi
78103fi
79104echo " Written MaxText compatible checkpoint to ${MODEL_BUCKET} /${MODEL} /${MODEL_VARIATION} /${idx} "
80105
81106# We define `SCANNED_CKPT_PATH` to refer to the checkpoint subdirectory.
82- # export SCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}/0/items
83107export SCANNED_CKPT_PATH=${MODEL_BUCKET} /${MODEL} /${MODEL_VARIATION} /${idx}
84108
85109# Convert MaxText compatible checkpoints to unscanned checkpoints.
86110# Note that `SCANNED_CKPT_PATH` is in a `scanned` format, which is great for training; for efficient decoding performance, we want the checkpoint in an `unscanned` format.
87111export RUN_NAME=${MODEL_NAME} _unscanned_chkpt_${idx}
88112
89- if [[ -x " ${LORA_LOCAL_PATH } " ]]; then
113+ if [[ ! -z " ${LORA_INPUT_ADAPTERS_PATH } " ]]; then
90114 JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py \
91115 MaxText/configs/base.yml \
92116 base_output_directory=${BASE_OUTPUT_DIRECTORY} \
93- load_parameters_path=${SCANNED_CKPT_PATH} /base_weights/0/items \
94- lora_parameters_base_path=${SCANNED_CKPT_PATH} /lora_weights/0/items \
95- lora_config_path=${LORA_LOCAL_PATH} /adapter_config.json \
117+ load_parameters_path=${SCANNED_CKPT_PATH} /base/0/items \
118+ lora_input_adapters_path=${SCANNED_CKPT_PATH} /LoRAs \
96119 run_name=${RUN_NAME} \
97120 model_name=${MODEL_NAME} \
98121 force_unroll=true
101124 JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py \
102125 MaxText/configs/base.yml \
103126 base_output_directory=${BASE_OUTPUT_DIRECTORY} \
104- load_parameters_path=${SCANNED_CKPT_PATH} /base_weights/ 0/items \
127+ load_parameters_path=${SCANNED_CKPT_PATH} /0/items \
105128 run_name=${RUN_NAME} \
106129 model_name=${MODEL_NAME} \
107130 force_unroll=true
0 commit comments