@@ -223,13 +223,18 @@ struct kl_divergence_result {
223223 double sum_kld2 = 0 ;
224224 double sum_nll_diff = 0 ;
225225 double sum_nll_diff2 = 0 ;
226+ size_t n_same_top = 0 ;
226227 size_t count = 0 ;
227228};
228229
229- static void log_softmax (int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) {
230+ static double log_softmax (int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) {
230231 float max_logit = logits[0 ];
232+ int imax = 0 ;
231233 for (int i = 1 ; i < n_vocab; ++i) {
232- max_logit = std::max (max_logit, logits[i]);
234+ if (logits[i] > max_logit) {
235+ max_logit = logits[i];
236+ imax = i;
237+ }
233238 }
234239 double sum_exp = 0.0 ;
235240 for (int i = 0 ; i < n_vocab; ++i) {
@@ -248,8 +253,14 @@ static void log_softmax(int n_vocab, const float * logits, const uint16_t * base
248253 kld.sum_nll_diff2 += nll*nll;
249254 max_logit += log_sum_exp;
250255 double sum = 0 ;
256+ int imax_base = -1 ;
257+ float p_log_base_max = 0 ;
251258 for (int i = 0 ; i < n_vocab; ++i) {
252259 const float p_log_base = scale*base_log_prob[i] + min_log_prob;
260+ if (i == 0 || p_log_base > p_log_base_max) {
261+ p_log_base_max = p_log_base;
262+ imax_base = i;
263+ }
253264 if (p_log_base > -16 .f ) {
254265 const float p_base = expf (p_log_base);
255266 sum += p_base * (p_log_base - logits[i] + max_logit);
@@ -258,14 +269,17 @@ static void log_softmax(int n_vocab, const float * logits, const uint16_t * base
258269 kld.sum_kld += sum;
259270 kld.sum_kld2 += sum*sum;
260271 ++kld.count ;
272+ if (imax == imax_base) ++kld.n_same_top ;
273+ return sum;
261274}
262275
263276static void process_logits (int n_vocab, const float * logits, const int * tokens, int n_token,
264- std::vector<std::thread> & workers, const std::vector<uint16_t > & base_log_probs, kl_divergence_result & kld) {
277+ std::vector<std::thread> & workers, const std::vector<uint16_t > & base_log_probs, kl_divergence_result & kld,
278+ float * kld_values) {
265279 std::mutex mutex;
266280 const int nv = 2 *((n_vocab + 1 )/2 ) + 4 ;
267281 int counter = 0 ;
268- auto compute = [&mutex, &counter, &base_log_probs, &kld, n_vocab, logits, tokens, n_token, nv] () {
282+ auto compute = [&mutex, &counter, &base_log_probs, &kld, n_vocab, logits, tokens, n_token, nv, kld_values ] () {
269283 kl_divergence_result local_kld;
270284 while (true ) {
271285 std::unique_lock<std::mutex> lock (mutex);
@@ -277,11 +291,13 @@ static void process_logits(int n_vocab, const float * logits, const int * tokens
277291 kld.sum_kld2 += local_kld.sum_kld2 ;
278292 kld.sum_nll_diff += local_kld.sum_nll_diff ;
279293 kld.sum_nll_diff2 += local_kld.sum_nll_diff2 ;
294+ kld.n_same_top += local_kld.n_same_top ;
280295 kld.count += local_kld.count ;
281296 break ;
282297 }
283298 lock.unlock ();
284- log_softmax (n_vocab, logits + i*n_vocab, base_log_probs.data () + i*nv, tokens[i+1 ], local_kld);
299+ double v = log_softmax (n_vocab, logits + i*n_vocab, base_log_probs.data () + i*nv, tokens[i+1 ], local_kld);
300+ kld_values[i] = (float )v;
285301 }
286302 };
287303 for (auto & w : workers) {
@@ -1203,11 +1219,11 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
12031219 printf (" Final Winogrande score(%d tasks): %.4lf +/- %.4lf\n " , n_done, 100 *p, sigma);
12041220}
12051221
1206- static bool deserialize_string (std::istream& in, std::string& str) {
1222+ static bool deserialize_string (std::istream & in, std::string & str) {
12071223 uint32_t size;
12081224 if (!in.read ((char *)&size, sizeof (size)).fail ()) {
12091225 str.resize (size);
1210- if (!in.read ((char *)str. data () , size).fail ()) return true ;
1226+ if (!in.read ((char *)& str[ 0 ] , size).fail ()) return true ;
12111227 }
12121228 return false ;
12131229}
@@ -1616,7 +1632,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
16161632 in.read ((char *)&n_vocab, sizeof (n_vocab));
16171633 in.read ((char *)&n_chunk, sizeof (n_chunk));
16181634 if (in.fail ()) {
1619- fprintf (stderr, " %s: failed rwading n_vocab, n_chunk from %s\n " , __func__, params.logits_file .c_str ());
1635+ fprintf (stderr, " %s: failed reading n_vocab, n_chunk from %s\n " , __func__, params.logits_file .c_str ());
16201636 return ;
16211637 }
16221638 if (n_vocab != llama_n_vocab (llama_get_model (ctx))) {
@@ -1635,6 +1651,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
16351651 const bool add_bos = llama_should_add_bos_token (llama_get_model (ctx));
16361652
16371653 std::vector<uint16_t > log_probs_uint16 (size_t (n_ctx - 1 - n_ctx/2 ) * nv);
1654+ std::vector<float > kld_values (size_t (n_ctx - 1 - n_ctx/2 )*n_chunk);
16381655 std::vector<float > logits;
16391656 if (num_batches > 1 ) {
16401657 logits.reserve (n_ctx * n_vocab);
@@ -1653,6 +1670,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
16531670 };
16541671
16551672 kl_divergence_result kld;
1673+ auto kld_ptr = kld_values.data ();
16561674
16571675 for (int i = 0 ; i < n_chunk; ++i) {
16581676 const int start = i * n_ctx;
@@ -1706,27 +1724,60 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
17061724 }
17071725 fprintf (stderr, " %.2f minutes\n " , total_seconds / 60.0 );
17081726
1709- printf (" \n chunk PPL ln(PPL(Q)/PPL(base)) KL-Divergence\n " );
1727+ printf (" \n chunk PPL ln(PPL(Q)/PPL(base)) KL-Divergence Same top \n " );
17101728 }
17111729
17121730 const int first = n_ctx/2 ;
17131731 const float * all_logits = num_batches > 1 ? logits.data () : llama_get_logits (ctx);
17141732 process_logits (n_vocab, all_logits + first*n_vocab, tokens.data () + start + first, n_ctx - 1 - first,
1715- workers, log_probs_uint16, kld);
1733+ workers, log_probs_uint16, kld, kld_ptr);
1734+ kld_ptr += n_ctx - 1 - first;
17161735
17171736 auto ppl = mean_and_uncertainty (kld.sum_nll , kld.sum_nll2 , kld.count );
17181737 auto log_ppl_ratio = mean_and_uncertainty (kld.sum_nll_diff , kld.sum_nll_diff2 , kld.count );
17191738 auto kl_div = mean_and_uncertainty (kld.sum_kld , kld.sum_kld2 , kld.count );
1739+ auto p_top = 1 .*kld.n_same_top /kld.count ;
1740+ auto d_p_top = sqrt (p_top*(1 - p_top)/(kld.count - 1 ));
17201741
1721- printf (" %4d %10.4lf %10.5lf ± %10.5f %10.5f ± %10.5lf\n " , i+1 , exp (ppl.first ),
1722- log_ppl_ratio.first , log_ppl_ratio.second , kl_div.first , kl_div.second );
1742+ printf (" %4d %10.4lf %10.5lf ± %10.5f %10.5f ± %10.5lf %.5f ± %.5f\n " , i+1 , exp (ppl.first ),
1743+ log_ppl_ratio.first , log_ppl_ratio.second , kl_div.first , kl_div.second ,
1744+ p_top, d_p_top);
17231745
17241746 fflush (stdout);
17251747
17261748 logits.clear ();
17271749 }
17281750 printf (" \n " );
17291751
1752+ if (kld.count < 100 ) return ; // we do not wish to do statistics on so few values
1753+
1754+ std::sort (kld_values.begin (), kld_values.end ());
1755+
1756+ printf (" ===== KL-divergence statistics\n " );
1757+ auto kl_div = mean_and_uncertainty (kld.sum_kld , kld.sum_kld2 , kld.count );
1758+ printf (" Average: %10.6f ±%10.6lf\n " , kl_div.first , kl_div.second );
1759+ auto kld_median = kld_values.size ()%2 == 0 ? 0 .5f *(kld_values[kld_values.size ()/2 ] + kld_values[kld_values.size ()/2 -1 ])
1760+ : kld_values[kld_values.size ()/2 ];
1761+ printf (" Median : %10.6f\n " , kld_median);
1762+
1763+ auto percentile = [&kld_values] (float fraction) {
1764+ if (fraction <= 0 ) return kld_values.front ();
1765+ if (fraction >= 1 ) return kld_values.back ();
1766+ float p = fraction*(kld_values.size () - 1 );
1767+ size_t ip = size_t (p); p -= ip;
1768+ return (1 - p)*kld_values[ip] + p*kld_values[std::min (ip+1 , kld_values.size ()-1 )];
1769+ };
1770+
1771+ printf (" Maximum: %10.6f\n " , kld_values.back ());
1772+ printf (" KLD_99 : %10.6f\n " , percentile (0 .99f ));
1773+ printf (" KLD_95 : %10.6f\n " , percentile (0 .95f ));
1774+ printf (" KLD_90 : %10.6f\n " , percentile (0 .90f ));
1775+
1776+ printf (" Minimum: %10.6f\n " , kld_values.front ());
1777+ printf (" KLD_01 : %10.6f\n " , percentile (0 .01f ));
1778+ printf (" KLD_05 : %10.6f\n " , percentile (0 .05f ));
1779+ printf (" KLD_10 : %10.6f\n " , percentile (0 .10f ));
1780+
17301781}
17311782
17321783int main (int argc, char ** argv) {
0 commit comments