@@ -697,10 +697,12 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     // determine rope scaling params
     float rope_freq_scale = 1.0f;
     float rope_freq_base = 10000.0f;
+    bool overwriteRope = false;
     if(inputs.rope_freq_scale > 0.0f)
     {
         rope_freq_scale = inputs.rope_freq_scale;
         rope_freq_base = inputs.rope_freq_base;
+        overwriteRope = true;
         printf("Using Custom RoPE scaling (scale:%.3f, base:%.1f).\n", rope_freq_scale, rope_freq_base);
     }
     else
@@ -722,13 +724,9 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
             rope_freq_base = (effectivenctx <= 2048 ? 10000.0f : (effectivenctx <= 3072 ? 26000.0f : (effectivenctx <= 4096 ? 32000.0f : (effectivenctx <= 6144 ? 54000.0f :
             (effectivenctx <= 8192 ? 82684.0f : (effectivenctx <= 12288 ? 140000.0f : (effectivenctx <= 16384 ? 200000.0f : (effectivenctx <= 24576 ? 320000.0f : 440000.0f))))))));

-            if(file_format_meta.freq_base_train > rope_freq_base)
-            {
-                rope_freq_base = file_format_meta.freq_base_train;
-            }
         }

-        printf("Using automatic RoPE scaling (scale:%.3f, base:%.1f)\n", rope_freq_scale, rope_freq_base);
+        printf("Using automatic RoPE scaling. If the model has customized RoPE settings, they will be used directly instead!\n");
     }
     gptj_ctx_v3.hparams.rope_freq_scale = neox_ctx_v3.hparams.rope_freq_scale = rope_freq_scale;
     gptj_ctx_v3.hparams.rope_freq_base = neox_ctx_v3.hparams.rope_freq_base = rope_freq_base;
@@ -903,8 +901,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
         }
         #endif
         model_params.main_gpu = cu_parseinfo_maindevice;
-        llama_ctx_params.rope_freq_base = rope_freq_base;
-        llama_ctx_params.rope_freq_scale = rope_freq_scale;
+
         llama_ctx_params.n_batch = blasbatchsize;
         llama_ctx_params.n_threads = n_threads;
         llama_ctx_params.n_threads_batch = n_blasthreads;
@@ -932,6 +929,28 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
         }

         llama_model * llamamodel = llama_load_model_from_file(modelname.c_str(), model_params);
+        if(overwriteRope)
+        {
+            llama_ctx_params.rope_freq_base = rope_freq_base;
+            llama_ctx_params.rope_freq_scale = rope_freq_scale;
+        }
+        else
+        {
+            // if the model modifies rope in any way, use the model values. Otherwise, use our automatic ones.
+            if(llamamodel->hparams.rope_freq_base_train != 10000.0f ||
+            llamamodel->hparams.rope_freq_scale_train != 1.0f ||
+            llamamodel->hparams.rope_scaling_type_train == 2)
+            {
+                printf("Automatic RoPE Scaling: Using model internal values.\n");
+            }
+            else
+            {
+                llama_ctx_params.rope_freq_base = rope_freq_base;
+                llama_ctx_params.rope_freq_scale = rope_freq_scale;
+                printf("Automatic RoPE Scaling: Using (scale:%.3f, base:%.1f).\n", rope_freq_scale, rope_freq_base);
+            }
+        }
+
         llama_ctx_v4 = llama_new_context_with_model(llamamodel, llama_ctx_params);

         if(llama_ctx_v4 == NULL)
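
For reference, the selection order these hunks implement (explicit user override, else model-internal RoPE metadata, else the automatic context-size ladder) can be summarized as a small standalone C++ sketch. Names such as RopeHParams and pick_rope_params are illustrative stand-ins, not the real koboldcpp/llama.cpp API; the hparams fields only mirror llama.cpp's rope_freq_base_train / rope_freq_scale_train / rope_scaling_type_train.

// Minimal standalone sketch of the RoPE parameter selection above (illustrative only).
#include <cstdio>

struct RopeHParams {
    float freq_base_train = 10000.0f;   // values baked into the model's GGUF metadata
    float freq_scale_train = 1.0f;
    int   scaling_type_train = 0;       // 2 corresponds to the non-linear scaling type checked in the diff
};

struct RopeChoice {
    float freq_base;
    float freq_scale;
    bool  from_model;                   // true: leave ctx params alone and trust the model's own values
};

// user_scale <= 0 means "no explicit override"; effectivenctx is the requested context size.
static RopeChoice pick_rope_params(int effectivenctx, float user_scale, float user_base, const RopeHParams &hp) {
    if (user_scale > 0.0f) {
        return { user_base, user_scale, false };   // custom override always wins
    }
    // same context-size ladder as the diff's automatic rope_freq_base
    float auto_base = (effectivenctx <= 2048 ? 10000.0f : (effectivenctx <= 3072 ? 26000.0f :
                      (effectivenctx <= 4096 ? 32000.0f : (effectivenctx <= 6144 ? 54000.0f :
                      (effectivenctx <= 8192 ? 82684.0f : (effectivenctx <= 12288 ? 140000.0f :
                      (effectivenctx <= 16384 ? 200000.0f : (effectivenctx <= 24576 ? 320000.0f : 440000.0f))))))));
    if (hp.freq_base_train != 10000.0f || hp.freq_scale_train != 1.0f || hp.scaling_type_train == 2) {
        return { hp.freq_base_train, hp.freq_scale_train, true };   // model ships its own RoPE config
    }
    return { auto_base, 1.0f, false };
}

int main() {
    RopeHParams hp;   // defaults: a model with no customized RoPE metadata
    RopeChoice c = pick_rope_params(8192, 0.0f, 0.0f, hp);
    printf("base=%.1f scale=%.3f from_model=%d\n", c.freq_base, c.freq_scale, c.from_model);
    return 0;
}

With the defaults above, an 8192-token request resolves to the automatic base of 82684.0; a model whose GGUF metadata carries non-default RoPE values would instead be used as-is, matching the "Using model internal values" branch in the diff.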