@@ -37,38 +37,48 @@ using json = nlohmann::ordered_json;
3737
3838constexpr int HTTP_POLLING_SECONDS = 1 ;
3939
40- static uint32_t server_n_outputs_per_seq (const common_params_speculative & speculative) {
41- uint32_t n_outputs = 1 ;
40+ static uint32_t server_n_outputs_max (const common_params & params) {
41+ const uint32_t n_batch = params.n_batch ;
42+ const uint32_t n_ubatch = std::min (n_batch, params.n_ubatch == 0 ? n_batch : params.n_ubatch );
4243
43- for (const auto type : speculative.types ) {
44+ if (params.embedding ||
45+ (params.pooling_type != LLAMA_POOLING_TYPE_UNSPECIFIED && params.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
46+ return n_ubatch;
47+ }
48+
49+ uint32_t n_outputs_per_seq = 1 ;
50+
51+ for (const auto type : params.speculative .types ) {
4452 switch (type) {
4553 case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE:
4654 case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3:
4755 case COMMON_SPECULATIVE_TYPE_DRAFT_MTP:
48- n_outputs = std::max<uint32_t >(n_outputs , 1 + std::max (0 , speculative.draft .n_max ));
56+ n_outputs_per_seq = std::max<uint32_t >(n_outputs_per_seq , 1 + std::max (0 , params. speculative .draft .n_max ));
4957 break ;
5058 case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE:
51- n_outputs = std::max<uint32_t >(n_outputs , 1 + speculative.ngram_simple .size_m );
59+ n_outputs_per_seq = std::max<uint32_t >(n_outputs_per_seq , 1 + params. speculative .ngram_simple .size_m );
5260 break ;
5361 case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K:
54- n_outputs = std::max<uint32_t >(n_outputs , 1 + speculative.ngram_map_k .size_m );
62+ n_outputs_per_seq = std::max<uint32_t >(n_outputs_per_seq , 1 + params. speculative .ngram_map_k .size_m );
5563 break ;
5664 case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V:
57- n_outputs = std::max<uint32_t >(n_outputs , 1 + speculative.ngram_map_k4v .size_m );
65+ n_outputs_per_seq = std::max<uint32_t >(n_outputs_per_seq , 1 + params. speculative .ngram_map_k4v .size_m );
5866 break ;
5967 case COMMON_SPECULATIVE_TYPE_NGRAM_MOD:
60- n_outputs = std::max<uint32_t >(n_outputs , 1 + std::max (0 , speculative.ngram_mod .n_max ));
68+ n_outputs_per_seq = std::max<uint32_t >(n_outputs_per_seq , 1 + std::max (0 , params. speculative .ngram_mod .n_max ));
6169 break ;
6270 case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE:
63- n_outputs = std::max<uint32_t >(n_outputs , 1 + 8 );
71+ n_outputs_per_seq = std::max<uint32_t >(n_outputs_per_seq , 1 + 8 );
6472 break ;
6573 case COMMON_SPECULATIVE_TYPE_NONE:
6674 case COMMON_SPECULATIVE_TYPE_COUNT:
6775 break ;
6876 }
6977 }
7078
71- return n_outputs;
79+ const uint64_t n_outputs = (uint64_t ) params.n_parallel * n_outputs_per_seq;
80+
81+ return std::max<uint32_t >(1 , std::min<uint64_t >(n_ubatch, n_outputs));
7282}
7383
7484// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283
@@ -787,7 +797,7 @@ struct server_context_impl {
787797 SRV_INF (" loading model '%s'\n " , params.model .path .c_str ());
788798
789799 params_base = params;
790- params_base.n_outputs_per_seq = server_n_outputs_per_seq (params_base. speculative );
800+ params_base.n_outputs_max = server_n_outputs_max (params_base);
791801
792802 std::string & mmproj_path = params_base.mmproj .path ;
793803 bool has_mmproj = !mmproj_path.empty ();
@@ -854,7 +864,7 @@ struct server_context_impl {
854864 }
855865
856866 if (!has_draft) {
857- params_dft.n_outputs_per_seq = 1 ;
867+ params_dft.n_outputs_max = params_base. n_parallel ;
858868 }
859869
860870 auto mparams_dft = common_model_params_to_llama (params_dft);
@@ -980,11 +990,11 @@ struct server_context_impl {
980990 params_base.model .path .c_str ());
981991
982992 auto cparams_mtp = common_context_params_to_llama (params_base);
983- cparams_mtp.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
984- cparams_mtp.type_k = params_base.speculative .draft .cache_type_k ;
985- cparams_mtp.type_v = params_base.speculative .draft .cache_type_v ;
986- cparams_mtp.n_rs_seq = 0 ;
987- cparams_mtp.n_outputs_per_seq = 1 ;
993+ cparams_mtp.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
994+ cparams_mtp.type_k = params_base.speculative .draft .cache_type_k ;
995+ cparams_mtp.type_v = params_base.speculative .draft .cache_type_v ;
996+ cparams_mtp.n_rs_seq = 0 ;
997+ cparams_mtp.n_outputs_max = params_base. n_parallel ;
988998
989999 ctx_dft.reset (llama_init_from_model (model_tgt, cparams_mtp));
9901000 if (ctx_dft == nullptr ) {
0 commit comments