@@ -110,13 +110,21 @@ int main(int argc, char ** argv) {
110110 return 1 ;
111111 }
112112
113- if (
114- llama_vocab_get_add_bos (vocab_tgt) != llama_vocab_get_add_bos (vocab_dft) ||
115- llama_vocab_get_add_eos (vocab_tgt) != llama_vocab_get_add_eos (vocab_dft) ||
116- llama_vocab_bos (vocab_tgt) != llama_vocab_bos (vocab_dft) ||
117- llama_vocab_eos (vocab_tgt) != llama_vocab_eos (vocab_dft)
118- ) {
119- LOG_ERR (" %s: draft model special tokens must match target model to use speculation\n " , __func__);
113+ if (llama_vocab_get_add_bos (vocab_tgt) != llama_vocab_get_add_bos (vocab_dft) ||
114+ (llama_vocab_get_add_bos (vocab_tgt) && llama_vocab_bos (vocab_tgt) != llama_vocab_bos (vocab_dft))) {
115+ LOG_ERR (" %s: draft model bos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n " ,
116+ __func__,
117+ llama_vocab_get_add_bos (vocab_tgt), llama_vocab_get_add_bos (vocab_dft),
118+ llama_vocab_bos (vocab_tgt), llama_vocab_bos (vocab_dft));
119+ return 1 ;
120+ }
121+
122+ if (llama_vocab_get_add_eos (vocab_tgt) != llama_vocab_get_add_eos (vocab_dft) ||
123+ (llama_vocab_get_add_eos (vocab_tgt) && llama_vocab_eos (vocab_tgt) != llama_vocab_eos (vocab_dft))) {
124+ LOG_ERR (" %s: draft model eos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n " ,
125+ __func__,
126+ llama_vocab_get_add_eos (vocab_tgt), llama_vocab_get_add_eos (vocab_dft),
127+ llama_vocab_eos (vocab_tgt), llama_vocab_eos (vocab_dft));
120128 return 1 ;
121129 }
122130
@@ -137,11 +145,12 @@ int main(int argc, char ** argv) {
137145 for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min (n_vocab_tgt, n_vocab_dft); ++i) {
138146 const char * token_text_tgt = llama_vocab_get_text (vocab_tgt, i);
139147 const char * token_text_dft = llama_vocab_get_text (vocab_dft, i);
148+
140149 if (std::strcmp (token_text_tgt, token_text_dft) != 0 ) {
141150 LOG_ERR (" %s: draft model vocab must match target model to use speculation but " , __func__);
142151 LOG_ERR (" token %d content differs - target '%s', draft '%s'\n " , i,
143- common_token_to_piece (ctx_tgt , i).c_str (),
144- common_token_to_piece (ctx_dft , i).c_str ());
152+ common_token_to_piece (vocab_tgt , i).c_str (),
153+ common_token_to_piece (vocab_dft , i).c_str ());
145154 return 1 ;
146155 }
147156 }
0 commit comments