@@ -157,7 +157,7 @@ static void process_logits(std::ostream& out, int n_vocab, const float * logits,
157157 break ;
158158 }
159159 lock.unlock ();
160- const double v = log_softmax (n_vocab, logits + size_t (i)*n_vocab, log_probs.data () + i *nv, tokens[i+1 ]);
160+ const double v = log_softmax (n_vocab, logits + size_t (i)*n_vocab, log_probs.data () + size_t (i) *nv, tokens[i+1 ]);
161161 local_nll += v;
162162 local_nll2 += v*v;
163163 }
@@ -169,7 +169,7 @@ static void process_logits(std::ostream& out, int n_vocab, const float * logits,
169169 for (auto & w : workers) {
170170 w.join ();
171171 }
172- out.write ((const char *)log_probs.data (), n_token*nv*sizeof (uint16_t ));
172+ out.write ((const char *)log_probs.data (), size_t ( n_token) *nv*sizeof (uint16_t ));
173173}
174174
175175struct kl_divergence_result {
@@ -279,7 +279,7 @@ static void process_logits(int n_vocab, const float * logits, const int * tokens
279279 break ;
280280 }
281281 lock.unlock ();
282- std::pair<double , float > v = log_softmax (n_vocab, logits + size_t (i)*n_vocab, base_log_probs.data () + i *nv, tokens[i+1 ], local_kld);
282+ std::pair<double , float > v = log_softmax (n_vocab, logits + size_t (i)*n_vocab, base_log_probs.data () + size_t (i) *nv, tokens[i+1 ], local_kld);
283283 kld_values[i] = (float )v.first ;
284284 p_diff_values[i] = v.second ;
285285 }
0 commit comments