Skip to content

Commit 3b30f12

Browse files
committed
future proof handling of rnn models
1 parent 7857578 commit 3b30f12

2 files changed

Lines changed: 19 additions & 5 deletions

File tree

gpttype_adapter.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -484,8 +484,15 @@ void ContextRewind(std::vector<int> &embd, std::vector<int> &current_context_tok
484484
printf("\nWARNING: Don't use context rewind when in batch processing phase!\n");
485485
return;
486486
}
487-
bool is_recurrent = (file_format == FileFormat::GGUF_GENERIC && (file_format_meta.model_architecture==GGUFArch::ARCH_MAMBALIKE
488-
|| file_format_meta.model_architecture==GGUFArch::ARCH_RWKV));
487+
bool is_recurrent = false;
488+
if(file_format==FileFormat::GGUF_GENERIC)
489+
{
490+
const llama_model * mdl = llama_get_model(llama_ctx_v4);
491+
if(llama_model_is_recurrent(mdl) || llama_model_is_hybrid(mdl))
492+
{
493+
is_recurrent = true;
494+
}
495+
}
489496
if(file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2 || is_recurrent)
490497
{
491498
printf("\nWARNING: RNN models do not support context rewind!\n");
@@ -3747,8 +3754,15 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
37473754
printf("%s\n", RemoveBell(outstr).c_str());
37483755
}
37493756

3750-
bool is_recurrent = (file_format == FileFormat::GGUF_GENERIC && (file_format_meta.model_architecture==GGUFArch::ARCH_MAMBALIKE
3751-
|| file_format_meta.model_architecture==GGUFArch::ARCH_RWKV));
3757+
bool is_recurrent = false;
3758+
if(file_format==FileFormat::GGUF_GENERIC)
3759+
{
3760+
const llama_model * mdl = llama_get_model(llama_ctx_v4);
3761+
if(llama_model_is_recurrent(mdl) || llama_model_is_hybrid(mdl))
3762+
{
3763+
is_recurrent = true;
3764+
}
3765+
}
37523766
bool blank_prompt = (addedmemory=="" && kcpp_data->prompt=="");
37533767

37543768
if (file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2 || is_recurrent)

model_adapter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ std::string gguf_get_model_arch(const std::string & gguf_filename)
368368
fileformatmeta->model_architecture = GGUFArch::ARCH_FALCON;
369369
}
370370
else if(modelarch=="mamba" || modelarch=="mamba2" || modelarch=="nemotron_h" || modelarch=="jamba" || modelarch=="granitehybrid" || modelarch=="lfm2"
371-
|| modelarch=="plamo2" || modelarch=="falcon-h1") //lazy approach, put all RNN models
371+
|| modelarch=="plamo2" || modelarch=="falcon-h1") //lazy approach, put all non rwkv RNN models
372372
{
373373
fileformatmeta->model_architecture = GGUFArch::ARCH_MAMBALIKE;
374374
}

0 commit comments

Comments
 (0)