@@ -13,6 +13,8 @@
 #include <executorch/runtime/platform/log.h>
 #include <pytorch/tokenizers/hf_tokenizer.h>
 
+#include <cinttypes>
+#include <cstdio>
 #include <string>
 #include <vector>
 
@@ -67,7 +69,43 @@ int main(int argc, char** argv) {
   config.temperature = FLAGS_temperature;
   config.max_new_tokens = FLAGS_max_new_tokens;
 
-  auto error = runner->generate(FLAGS_prompt.c_str(), config);
+  auto error = runner->generate(
+      FLAGS_prompt.c_str(),
+      config,
+      /*token_callback=*/{},
+      [](const llm::Stats& stats) {
+        double scale = stats.SCALING_FACTOR_UNITS_PER_SECOND;
+        double model_load_s =
+            (stats.model_load_end_ms - stats.model_load_start_ms) / scale;
+        double inference_s =
+            (stats.inference_end_ms - stats.inference_start_ms) / scale;
+        double prefill_s =
+            (stats.prompt_eval_end_ms - stats.inference_start_ms) / scale;
+        double decode_s =
+            (stats.inference_end_ms - stats.prompt_eval_end_ms) / scale;
+        double ttft_s =
+            (stats.first_token_ms - stats.inference_start_ms) / scale;
+        double sampling_s = stats.aggregate_sampling_time_ms / scale;
+
+        printf("\n\tPrompt Tokens: %" PRIu64 " Generated Tokens: %" PRIu64,
+            stats.num_prompt_tokens, stats.num_generated_tokens);
+        printf("\n\tModel Load Time:\t\t%f (seconds)", model_load_s);
+        printf(
+            "\n\tTotal inference time:\t\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
+            inference_s, stats.num_generated_tokens / inference_s);
+        printf(
+            "\n\t\tPrompt evaluation:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
+            prefill_s, stats.num_prompt_tokens / prefill_s);
+        printf(
+            "\n\t\tGenerated %" PRIu64
+            " tokens:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
+            stats.num_generated_tokens, decode_s,
+            stats.num_generated_tokens / decode_s);
+        printf("\n\tTime to first generated token:\t%f (seconds)", ttft_s);
+        printf(
+            "\n\tSampling time over %" PRIu64 " tokens:\t%f (seconds)\n",
+            stats.num_prompt_tokens + stats.num_generated_tokens, sampling_s);
+      });
   if (error != executorch::runtime::Error::Ok) {
     ET_LOG(Error, "Generation failed");
     return 1;
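For comparison, the `token_callback` slot that this change leaves empty can also be used to stream text as it is decoded. The sketch below is illustrative and not part of this commit: it assumes the runner's token callback receives each decoded piece as a `const std::string&`, and it reuses the `runner`, `FLAGS_prompt`, and `config` objects from the surrounding `main()`.

```cpp
// Illustrative only (not in this commit): print each decoded piece as it
// arrives, leaving the stats callback empty.
auto error = runner->generate(
    FLAGS_prompt.c_str(),
    config,
    /*token_callback=*/[](const std::string& piece) {
      // Assumed: invoked once per decoded text piece during generation.
      printf("%s", piece.c_str());
      fflush(stdout);
    },
    /*stats_callback=*/{});
```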