Skip to content

Commit febaed1

Browse files
committed
1) Adding more comments at applying LoRA on Prefill params path.
2) Fixing model_ckpt_conversion.sh after refactoring and merging from main.
1 parent 26b1f37 commit febaed1

2 files changed

Lines changed: 46 additions & 11 deletions

File tree

jetstream/core/orchestrator.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -666,6 +666,18 @@ def _prefill_thread(self, idx: int):
666666

667667
adapter_id = request.adapter_id
668668

669+
# As prefill is happening one prompt at a time, for each prefill, we are
670+
# applying the LoRA params on base to create a copy of params (equivalent
671+
# to the size of base params) and use that for generating kv-cache. This
672+
# copy is called the final_prefill_params, which is deleted soon after the
673+
# generation of kv-cache.
674+
# We can have memory-optimizations by updating the original copy of
675+
# base params at the cost of extra computations to revert it back to original
676+
# base params after kv-cache computation of each prompt, so that it can
677+
# be used by the next prompt. But this optimization could also be tricky
678+
# because as of now same params are being shared by prefill and generate,
679+
# where generate always expect the base_params. So some race conditions need
680+
# to be avoided.
669681
final_prefill_params = None
670682
if adapter_id == "":
671683
final_prefill_params = prefill_params

jetstream/tools/maxtext/model_ckpt_conversion.sh

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,17 @@ export MODEL_BUCKET=$4
3838
# Point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you created, this bucket will store all the files generated by MaxText during a run, specifically the unscanned checkpoint.
3939
export BASE_OUTPUT_DIRECTORY=$5
4040

41-
export LORA_LOCAL_PATH=$6
41+
export HUGGING_FACE_CHECKPOINT=$6
42+
43+
export LORA_INPUT_ADAPTERS_PATH=$7
4244

4345
export BUCKET_LOCATION=US
4446

47+
if [[ -z "HUGGING_FACE_CHECKPOINT" ]]; then
48+
echo "HUGGING_FACE_CHECKPOINT is required."
49+
exit 1
50+
fi
51+
4552
# Create three GCS buckets for the demo.
4653
gcloud storage buckets create ${MODEL_BUCKET} --location=${BUCKET_LOCATION} || true
4754
gcloud storage buckets create ${BASE_OUTPUT_DIRECTORY} --location=${BUCKET_LOCATION} || true
@@ -59,40 +66,56 @@ else
5966
# llama_or_mistral_ckpt.py requires local path, so we need to copy the checkpoint from CHKPT_BUCKET to local.
6067
tmp_ckpt_path="/tmp/"
6168
gcloud storage cp -r ${CHKPT_BUCKET} ${tmp_ckpt_path}
69+
6270
path_parts=(${CHKPT_BUCKET//\// })
6371
directory_substring=${path_parts[-1]}
6472
CONVERT_CKPT_SCRIPT="llama_or_mistral_ckpt.py"
65-
if [[ -x "${LORA_LOCAL_PATH}" ]]; then
73+
74+
if [[ ! -z "${LORA_INPUT_ADAPTERS_PATH}" ]]; then
75+
lora_local_path="/tmp/"
76+
77+
if [[ "${LORA_INPUT_ADAPTERS_PATH}" =~ ^gs:// ]]; then
78+
path_parts=(${LORA_INPUT_ADAPTERS_PATH//\// })
79+
lora_dir_substring=${path_parts[-1]}
80+
81+
lora_local_path="${tmp_ckpt_path}${lora_dir_substring}"
82+
if [[ ! -d ${lora_local_path} ]]; then
83+
mkdir ${lora_local_path}
84+
fi
85+
gcloud storage cp -r ${LORA_INPUT_ADAPTERS_PATH} ${tmp_ckpt_path}
86+
else
87+
lora_local_path=${LORA_INPUT_ADAPTERS_PATH}
88+
fi
89+
6690
JAX_PLATFORMS=cpu python MaxText/${CONVERT_CKPT_SCRIPT} \
6791
--base-model-path ${tmp_ckpt_path}${directory_substring} \
6892
--maxtext-model-path ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx} \
6993
--model-size ${MODEL_NAME} \
70-
--lora-config-path ${LORA_LOCAL_PATH}/adapter_config.json \
71-
--lora-model-path ${LORA_LOCAL_PATH}/adapter_model.bin
94+
--lora-input-adapters-path ${lora_local_path} \
95+
--huggingface-checkpoint ${HUGGING_FACE_CHECKPOINT}
7296
else
7397
JAX_PLATFORMS=cpu python MaxText/${CONVERT_CKPT_SCRIPT} \
7498
--base-model-path ${tmp_ckpt_path}${directory_substring} \
7599
--maxtext-model-path ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx} \
76-
--model-size ${MODEL_NAME}
100+
--model-size ${MODEL_NAME} \
101+
--huggingface-checkpoint ${HUGGING_FACE_CHECKPOINT}
77102
fi
78103
fi
79104
echo "Written MaxText compatible checkpoint to ${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}"
80105

81106
# We define `SCANNED_CKPT_PATH` to refer to the checkpoint subdirectory.
82-
# export SCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}/0/items
83107
export SCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL}/${MODEL_VARIATION}/${idx}
84108

85109
# Convert MaxText compatible checkpoints to unscanned checkpoints.
86110
# Note that the `SCANNED_CKPT_PATH` is in a `scanned` format which is great for training but for efficient decoding performance we want the checkpoint in an `unscanned` format.
87111
export RUN_NAME=${MODEL_NAME}_unscanned_chkpt_${idx}
88112

89-
if [[ -x "${LORA_LOCAL_PATH}" ]]; then
113+
if [[ ! -z "${LORA_INPUT_ADAPTERS_PATH}" ]]; then
90114
JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py \
91115
MaxText/configs/base.yml \
92116
base_output_directory=${BASE_OUTPUT_DIRECTORY} \
93-
load_parameters_path=${SCANNED_CKPT_PATH}/base_weights/0/items \
94-
lora_parameters_base_path=${SCANNED_CKPT_PATH}/lora_weights/0/items \
95-
lora_config_path=${LORA_LOCAL_PATH}/adapter_config.json \
117+
load_parameters_path=${SCANNED_CKPT_PATH}/base/0/items \
118+
lora_input_adapters_path=${SCANNED_CKPT_PATH}/LoRAs \
96119
run_name=${RUN_NAME} \
97120
model_name=${MODEL_NAME} \
98121
force_unroll=true
@@ -101,7 +124,7 @@ else
101124
JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py \
102125
MaxText/configs/base.yml \
103126
base_output_directory=${BASE_OUTPUT_DIRECTORY} \
104-
load_parameters_path=${SCANNED_CKPT_PATH}/base_weights/0/items \
127+
load_parameters_path=${SCANNED_CKPT_PATH}/0/items \
105128
run_name=${RUN_NAME} \
106129
model_name=${MODEL_NAME} \
107130
force_unroll=true

0 commit comments

Comments
 (0)