Skip to content

Commit fc7fe2e

Browse files
committed
allow rwkv6 to run although it's broken
1 parent b631580 commit fc7fe2e

File tree

3 files changed

+25
-8
lines changed

3 files changed

+25
-8
lines changed

gpttype_adapter.cpp

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -194,15 +194,18 @@ static void TokenizeString(const std::string & str_to_tokenize, std::vector<int>
194194
if(add_bos)
195195
{
196196
llama_token bostoadd = llama_token_bos(&(llama_ctx_v4->model));
197-
if(output_tokens.size()==0)
197+
if(bostoadd != LLAMA_TOKEN_NULL) //if bos does not exist, do not add it
198198
{
199-
output_tokens.push_back(bostoadd);
200-
}
201-
else
202-
{
203-
if(output_tokens[0]!=bostoadd)
199+
if(output_tokens.size()==0)
204200
{
205-
output_tokens.insert(output_tokens.begin(), 1, bostoadd);
201+
output_tokens.push_back(bostoadd);
202+
}
203+
else
204+
{
205+
if(output_tokens[0]!=bostoadd)
206+
{
207+
output_tokens.insert(output_tokens.begin(), 1, bostoadd);
208+
}
206209
}
207210
}
208211
}
@@ -1870,6 +1873,12 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
18701873
}
18711874
}
18721875

1876+
if(file_format_meta.model_architecture==GGUFArch::ARCH_RWKV)
1877+
{
1878+
printf("\nRWKV6 Overriding EOS and BOS IDs to 0\n");
1879+
llamamodel->vocab.special_bos_id = llamamodel->vocab.special_eos_id = 0;
1880+
}
1881+
18731882
llama_ctx_params.flash_attn = kcpp_params->flash_attn;
18741883
llama_ctx_params.type_k = (inputs.quant_k>1?GGML_TYPE_Q4_0:(inputs.quant_k==1?GGML_TYPE_Q8_0:GGML_TYPE_F16));
18751884
llama_ctx_params.type_v = (inputs.quant_v>1?GGML_TYPE_Q4_0:(inputs.quant_v==1?GGML_TYPE_Q8_0:GGML_TYPE_F16));
@@ -3085,7 +3094,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
30853094
if (!inputs.allow_eos_token && !inputs.bypass_eos_token)
30863095
{
30873096
// set the logit of the eos token to very low to avoid sampling it
3088-
logitsPtr[eosID] = lowestLogit;
3097+
if(eosID!=LLAMA_TOKEN_NULL)
3098+
{
3099+
logitsPtr[eosID] = lowestLogit;
3100+
}
30893101
if(eotID!=-1)
30903102
{
30913103
logitsPtr[eotID] = lowestLogit;

model_adapter.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,10 @@ void print_tok_vec(std::vector<float> &embd)
314314
{
315315
fileformatmeta->model_architecture = GGUFArch::ARCH_QWEN2;
316316
}
317+
else if(modelarch=="rwkv6")
318+
{
319+
fileformatmeta->model_architecture = GGUFArch::ARCH_RWKV;
320+
}
317321
printf("Arch Category: %d\n",fileformatmeta->model_architecture);
318322

319323
}

model_adapter.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ enum GGUFArch
5858
ARCH_MAMBA = 3,
5959
ARCH_SOLAR = 4,
6060
ARCH_QWEN2 = 5,
61+
ARCH_RWKV = 6,
6162
};
6263

6364
struct FileFormatExtraMeta

0 commit comments

Comments
 (0)