From e1c100329029575e006d225fd9eb043fcb3b6c27 Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Tue, 28 Apr 2026 08:56:45 -0700
Subject: [PATCH 1/2] Add structured stats reporting and GPU memory tracking
 to Qwen3.5 MoE runner

Runner now uses llm::Stats with proper timestamps for model load,
prefill, decode, and GPU memory (via cudaMemGetInfo). Output matches
stats.h print_report format: PyTorchObserver JSON line plus
human-readable table.

This commit was authored with the assistance of Claude Code.

[ghstack-poisoned]
---
 examples/models/qwen3_5_moe/main.cpp | 130 +++++++++++++++++++++++----
 1 file changed, 114 insertions(+), 16 deletions(-)

diff --git a/examples/models/qwen3_5_moe/main.cpp b/examples/models/qwen3_5_moe/main.cpp
index c5024890645..0dd49280ef3 100644
--- a/examples/models/qwen3_5_moe/main.cpp
+++ b/examples/models/qwen3_5_moe/main.cpp
@@ -18,6 +18,7 @@
 #include
 #include
+#include
 #include
 #include
 #include
@@ -110,6 +111,17 @@ int main(int argc, char** argv) {
     return 1;
   }
 
+  // GPU memory: before load
+  {
+    size_t free = 0, total = 0;
+    if (cudaMemGetInfo(&free, &total) == cudaSuccess) {
+      stats.gpu_total_bytes = total;
+      stats.gpu_free_before_load_bytes = free;
+    }
+  }
+
+  stats.model_load_start_ms = llm::time_in_ms();
+
   // Create Module with share_memory_arenas=true so prefill and decode
   // share mutable buffers (KV cache, conv_state, recurrent_state).
   std::vector data_files;
@@ -184,11 +196,13 @@
   stats.model_load_end_ms = llm::time_in_ms();
 
-#ifdef EXECUTORCH_BUILD_CUDA
-  // GPU memory after load
-  cudaMemGetInfo(&gpu_free_bytes, &gpu_total_bytes);
-  stats.gpu_free_after_load_bytes = gpu_free_bytes;
-#endif
+  // GPU memory: after load
+  {
+    size_t free = 0, total = 0;
+    if (cudaMemGetInfo(&free, &total) == cudaSuccess) {
+      stats.gpu_free_after_load_bytes = free;
+    }
+  }
 
   // Get EOS ids
   auto eos_ids = llm::get_eos_ids(tokenizer.get(), module.get());
@@ -231,6 +245,9 @@
   auto temp_tensor =
       from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float);
 
+  stats.inference_start_ms = llm::time_in_ms();
+  stats.num_prompt_tokens = num_prompt_tokens;
+
   // ---------------------------------------------------------------
   // Prefill
   // ---------------------------------------------------------------
@@ -272,14 +289,14 @@
   cur_token = read_token(prefill_outputs[0].toTensor());
 
   stats.prompt_eval_end_ms = llm::time_in_ms();
-
+  stats.first_token_ms = stats.prompt_eval_end_ms;
   double prefill_ms =
       (double)(stats.prompt_eval_end_ms - stats.inference_start_ms);
   printf(
       "Prefill: %" PRId64 " tokens in %.1f ms (%.1f tok/s)\n",
       num_prompt_tokens,
       prefill_ms,
-      num_prompt_tokens * 1000.0 / prefill_ms);
+      num_prompt_tokens / prefill_ms * stats.SCALING_FACTOR_UNITS_PER_SECOND);
 
 #ifdef EXECUTORCH_BUILD_CUDA
   // Synchronize CUDA device to ensure prefill's writes to shared mutable
@@ -344,24 +361,105 @@
   int64_t num_generated = pos - num_prompt_tokens;
   stats.num_generated_tokens = num_generated;
 
+  // GPU memory: after generate + peak usage
+  {
+    size_t free = 0, total = 0;
+    if (cudaMemGetInfo(&free, &total) == cudaSuccess) {
+      stats.gpu_free_after_generate_bytes = free;
+      size_t min_free = free;
+      if (stats.gpu_free_before_load_bytes != static_cast<size_t>(-1)) {
+        min_free = std::min(min_free, (size_t)stats.gpu_free_before_load_bytes);
+      }
+      if (stats.gpu_free_after_load_bytes != static_cast<size_t>(-1)) {
+        min_free = std::min(min_free, (size_t)stats.gpu_free_after_load_bytes);
+      }
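A minimal, standalone sketch of the cudaMemGetInfo snapshot pattern this patch adds, for readers who want to try it outside the runner. Only the CUDA runtime API is assumed; free_before, free_after, and min_free are local stand-ins for the llm::Stats fields the patch writes to.

#include <cuda_runtime.h>

#include <algorithm>
#include <cstdio>

int main() {
  size_t free_before = 0, free_after = 0, total = 0;

  // Snapshot free device memory before the model is loaded.
  if (cudaMemGetInfo(&free_before, &total) != cudaSuccess) {
    return 1;
  }

  // ... load the model and run prefill/decode here ...

  // Snapshot again after generation.
  if (cudaMemGetInfo(&free_after, &total) != cudaSuccess) {
    return 1;
  }

  // Peak usage is estimated from the smallest observed free value.
  // Transient spikes between snapshots are invisible, so this is a
  // lower bound on the true peak.
  size_t min_free = std::min(free_before, free_after);
  double peak_mb = (double)(total - min_free) / 1024.0 / 1024.0;
  printf("GPU peak usage (approx): %.2f MB\n", peak_mb);
  return 0;
}

The same caveat applies to the runner: taking the minimum over the before-load, after-load, and after-generate snapshots captures the lowest free value at those three points only, not allocations made and released in between.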
+      stats.gpu_peak_usage_mb = (double)(total - min_free) / 1024.0 / 1024.0;
+    }
+  }
+
+  printf("\n");
+  double decode_ms =
+      (double)(stats.inference_end_ms - stats.prompt_eval_end_ms);
+  printf(
+      "Prefill: %" PRId64 " tokens in %.1f ms (%.1f tok/s)\n",
+      num_prompt_tokens,
+      prefill_ms,
+      num_prompt_tokens / prefill_ms * stats.SCALING_FACTOR_UNITS_PER_SECOND);
   printf(
       "Decode: %" PRId64 " tokens in %.1f ms (%.1f tok/s)\n",
       num_generated,
       decode_ms,
-      num_generated * 1000.0 / decode_ms);
+      num_generated / decode_ms * stats.SCALING_FACTOR_UNITS_PER_SECOND);
   printf("Prompt tokens: %" PRId64 "\n", num_prompt_tokens);
 
-#ifdef EXECUTORCH_BUILD_CUDA
-  // GPU memory after generation
-  cudaMemGetInfo(&gpu_free_bytes, &gpu_total_bytes);
-  stats.gpu_free_after_generate_bytes = gpu_free_bytes;
-  stats.gpu_peak_usage_mb =
-      (stats.gpu_total_bytes - gpu_free_bytes) / 1024.0 / 1024.0;
-#endif
+  // Structured stats report (matches stats.h print_report)
+  printf("PyTorchObserver %s\n", llm::stats_to_json_string(stats).c_str());
+
+  double ms_per_s = stats.SCALING_FACTOR_UNITS_PER_SECOND;
+
+  double model_load_s =
+      (double)(stats.model_load_end_ms - stats.model_load_start_ms) / ms_per_s;
+  double inference_time_ms =
+      (double)(stats.inference_end_ms - stats.inference_start_ms);
+  double prompt_eval_ms =
+      (double)(stats.prompt_eval_end_ms - stats.inference_start_ms);
+  double eval_ms =
+      (double)(stats.inference_end_ms - stats.prompt_eval_end_ms);
+  double ttft_s =
+      (double)(stats.first_token_ms - stats.inference_start_ms) / ms_per_s;
+  double sampling_s = (double)stats.aggregate_sampling_time_ms / ms_per_s;
 
-  llm::print_report(stats);
+  printf("\n");
+  printf(
+      "\tPrompt Tokens: %" PRId64 " Generated Tokens: %" PRId64 "\n",
+      stats.num_prompt_tokens,
+      stats.num_generated_tokens);
+  printf("\tModel Load Time:\t\t%f (seconds)\n", model_load_s);
+  printf(
+      "\tTotal inference time:\t\t%f (seconds)\t\t Rate: \t%f (tokens/second)\n",
+      inference_time_ms / ms_per_s,
+      stats.num_generated_tokens / inference_time_ms * ms_per_s);
+  printf(
+      "\t\tPrompt evaluation:\t%f (seconds)\t\t Rate: \t%f (tokens/second)\n",
+      prompt_eval_ms / ms_per_s,
+      stats.num_prompt_tokens / prompt_eval_ms * ms_per_s);
+  printf(
+      "\t\tGenerated %" PRId64
+      " tokens:\t%f (seconds)\t\t Rate: \t%f (tokens/second)\n",
+      stats.num_generated_tokens,
+      eval_ms / ms_per_s,
+      stats.num_generated_tokens / eval_ms * ms_per_s);
+  printf("\tTime to first generated token:\t%f (seconds)\n", ttft_s);
+  printf(
+      "\tSampling time over %" PRId64 " tokens:\t%f (seconds)\n",
+      stats.num_prompt_tokens + stats.num_generated_tokens,
+      sampling_s);
+
+  // GPU memory reporting
+  if (stats.gpu_total_bytes != static_cast<size_t>(-1)) {
+    printf(
+        "\tGPU total memory: %.2f MB\n",
+        stats.gpu_total_bytes / 1024.0 / 1024.0);
+    if (stats.gpu_free_before_load_bytes != static_cast<size_t>(-1)) {
+      printf(
+          "\tGPU free before load: %.2f MB\n",
+          stats.gpu_free_before_load_bytes / 1024.0 / 1024.0);
+    }
+    if (stats.gpu_free_after_load_bytes != static_cast<size_t>(-1)) {
+      printf(
+          "\tGPU free after load: %.2f MB\n",
+          stats.gpu_free_after_load_bytes / 1024.0 / 1024.0);
+    }
+    if (stats.gpu_free_after_generate_bytes != static_cast<size_t>(-1)) {
+      printf(
+          "\tGPU free after generate: %.2f MB\n",
+          stats.gpu_free_after_generate_bytes / 1024.0 / 1024.0);
+    }
+    if (stats.gpu_peak_usage_mb >= 0.0) {
+      printf("\tGPU peak usage: %.2f MB\n", stats.gpu_peak_usage_mb);
+    }
+  }
 
   return 0;
 }

From ca2f5dfcd47e4156a4d6b01e37e8b387c21a01c2 Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Tue, 28 Apr 2026 14:01:51 -0700
Subject: [PATCH 2/2] Update on "Add structured stats reporting and GPU memory
 tracking to Qwen3.5 MoE runner"

Runner now uses llm::Stats with proper timestamps for model load,
prefill, decode, and GPU memory (via cudaMemGetInfo). Output matches
stats.h print_report format: PyTorchObserver JSON line plus
human-readable table.

This commit was authored with the assistance of Claude Code.

[ghstack-poisoned]
---
 examples/models/qwen3_5_moe/main.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/examples/models/qwen3_5_moe/main.cpp b/examples/models/qwen3_5_moe/main.cpp
index 0dd49280ef3..5aca091fc16 100644
--- a/examples/models/qwen3_5_moe/main.cpp
+++ b/examples/models/qwen3_5_moe/main.cpp
@@ -404,8 +404,7 @@ int main(int argc, char** argv) {
       (double)(stats.inference_end_ms - stats.inference_start_ms);
   double prompt_eval_ms =
       (double)(stats.prompt_eval_end_ms - stats.inference_start_ms);
-  double eval_ms =
-      (double)(stats.inference_end_ms - stats.prompt_eval_end_ms);
+  double eval_ms = (double)(stats.inference_end_ms - stats.prompt_eval_end_ms);
   double ttft_s =
       (double)(stats.first_token_ms - stats.inference_start_ms) / ms_per_s;
   double sampling_s = (double)stats.aggregate_sampling_time_ms / ms_per_s;
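All of the rate lines in both patches use one formula: token count divided by elapsed milliseconds, scaled by units-per-second. A self-contained sketch of that arithmetic follows; kMsPerSecond and tokens_per_second are hypothetical names standing in for stats.SCALING_FACTOR_UNITS_PER_SECOND and the inline expressions in the runner, and the numbers are made up.

#include <cstdint>
#include <cstdio>

constexpr double kMsPerSecond = 1000.0;  // timestamps below are in ms

static double tokens_per_second(
    int64_t tokens, uint64_t start_ms, uint64_t end_ms) {
  // tokens / elapsed_ms yields tokens per millisecond; multiplying by
  // the units-per-second factor converts to tokens per second.
  double elapsed_ms = (double)(end_ms - start_ms);
  return (double)tokens / elapsed_ms * kMsPerSecond;
}

int main() {
  // Hypothetical run: 17 prompt tokens prefilled in 140 ms,
  // 64 tokens decoded in 3560 ms.
  printf("Prefill: %.1f tok/s\n", tokens_per_second(17, 100, 240));   // ~121.4
  printf("Decode:  %.1f tok/s\n", tokens_per_second(64, 240, 3800));  // ~18.0
  return 0;
}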
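The PyTorchObserver prefix exists so benchmark harnesses can pull one machine-readable record out of the runner's log with a simple grep. A hedged sketch of that output shape; the key names below are illustrative only, not the actual fields emitted by llm::stats_to_json_string in stats.h.

#include <cinttypes>
#include <cstdio>

int main() {
  // Hypothetical values; in the runner these come from llm::Stats.
  int64_t prompt_tokens = 17;
  int64_t generated_tokens = 64;

  // A fixed marker followed by one JSON object on a single line lets a
  // harness recover the record with `grep PyTorchObserver`.
  printf(
      "PyTorchObserver {\"prompt_tokens\":%" PRId64
      ",\"generated_tokens\":%" PRId64 "}\n",
      prompt_tokens,
      generated_tokens);
  return 0;
}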