Skip to content

Commit d343d49

Browse files
committed
Update on "add parakeet into cuda benchmark ci"
As titled. Differential Revision: [D92208958](https://our.internmc.facebook.com/intern/diff/D92208958/) [ghstack-poisoned]
1 parent 25cba56 commit d343d49

1 file changed

Lines changed: 18 additions & 0 deletions

File tree

examples/models/parakeet/main.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "types.h"
2727

2828
#include <executorch/extension/llm/runner/llm_runner_helper.h>
29+
#include <executorch/extension/llm/runner/stats.h>
2930
#include <executorch/extension/llm/runner/util.h>
3031
#include <executorch/extension/llm/runner/wav_loader.h>
3132
#include <executorch/extension/llm/tokenizers/third-party/llama.cpp-unicode/include/unicode.h>
@@ -334,6 +335,10 @@ std::vector<Token> greedy_decode_executorch(
334335
int main(int argc, char** argv) {
335336
gflags::ParseCommandLineFlags(&argc, &argv, true);
336337

338+
// Initialize stats for benchmarking
339+
::executorch::extension::llm::Stats stats;
340+
stats.model_load_start_ms = ::executorch::extension::llm::time_in_ms();
341+
337342
TimestampOutputMode timestamp_mode;
338343
try {
339344
timestamp_mode = parse_timestamp_output_mode(FLAGS_timestamps);
@@ -362,6 +367,8 @@ int main(int argc, char** argv) {
362367
ET_LOG(Error, "Failed to load model.");
363368
return 1;
364369
}
370+
stats.model_load_end_ms = ::executorch::extension::llm::time_in_ms();
371+
stats.inference_start_ms = ::executorch::extension::llm::time_in_ms();
365372

366373
// Load audio
367374
ET_LOG(Info, "Loading audio from: %s", FLAGS_audio_path.c_str());
@@ -465,11 +472,14 @@ int main(int argc, char** argv) {
465472
window_stride,
466473
encoder_subsampling_factor);
467474

475+
stats.prompt_eval_end_ms = ::executorch::extension::llm::time_in_ms();
476+
468477
ET_LOG(Info, "Running TDT greedy decode...");
469478
auto decoded_tokens = greedy_decode_executorch(
470479
*model, f_proj, encoded_len, blank_id, num_rnn_layers, pred_hidden);
471480

472481
ET_LOG(Info, "Decoded %zu tokens", decoded_tokens.size());
482+
stats.first_token_ms = stats.prompt_eval_end_ms; // For ASR, first token is at end of encoding
473483

474484
// Load tokenizer
475485
ET_LOG(Info, "Loading tokenizer from: %s", FLAGS_tokenizer_path.c_str());
@@ -488,6 +498,14 @@ int main(int argc, char** argv) {
488498
decoded_tokens, *tokenizer);
489499
std::cout << "Transcribed text: " << text << std::endl;
490500

501+
// Record inference end time and token counts
502+
stats.inference_end_ms = ::executorch::extension::llm::time_in_ms();
503+
stats.num_prompt_tokens = encoded_len; // Use encoder output length as "prompt" tokens
504+
stats.num_generated_tokens = static_cast<int64_t>(decoded_tokens.size());
505+
506+
// Print PyTorchObserver stats for benchmarking
507+
::executorch::extension::llm::print_report(stats);
508+
491509
#ifdef ET_BUILD_METAL
492510
executorch::backends::metal::print_metal_backend_stats();
493511
#endif // ET_BUILD_METAL

0 commit comments

Comments
 (0)