Skip to content

Commit 510e764

Browse files
committed
one graph && batching
1 parent d75a28d commit 510e764

1 file changed

Lines changed: 50 additions & 34 deletions

File tree

app/Graph/acc_check.cpp

Lines changed: 50 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,9 @@
1-
#ifndef WIN32_LEAN_AND_MEAN
1+
2+
#ifndef WIN32_LEAN_AND_MEAN
23
#define WIN32_LEAN_AND_MEAN
34
#endif
4-
#include <psapi.h>
55
#include <windows.h>
6+
#include <psapi.h>
67
#pragma comment(lib, "psapi.lib")
78
#include <crtdbg.h>
89
#include <algorithm>
@@ -85,7 +86,7 @@ int main(int argc, char* argv[]) {
8586
std::string model_name = "alexnet_mnist";
8687
RuntimeOptions options;
8788
size_t num_photo = 1000;
88-
size_t batch_size = 32;
89+
size_t batch_size = 50;
8990

9091
for (int i = 1; i < argc; ++i) {
9192
if (std::string(argv[i]) == "--model" && i + 1 < argc) {
@@ -151,7 +152,7 @@ int main(int argc, char* argv[]) {
151152
std::vector<int> input_shape = get_input_shape_from_json(json_path);
152153

153154
std::cout << '\n';
154-
int batch_count = 0;
155+
155156
if (model_name == "alexnet_mnist") {
156157
LOG_MEM("MNIST start");
157158

@@ -188,7 +189,7 @@ int main(int argc, char* argv[]) {
188189
for (int j = 0; j < 28; ++j) {
189190
size_t a = ind;
190191
for (size_t n = 0; n < name; n++) a += counts[n] + 1;
191-
res[(a) * 28 * 28 + i * 28 + j] = channels[0].at<uchar>(j, i);
192+
res[(a)*28 * 28 + i * 28 + j] = channels[0].at<uchar>(j, i);
192193
}
193194
}
194195
}
@@ -335,9 +336,42 @@ int main(int argc, char* argv[]) {
335336
int correct_predictions_top1 = 0;
336337
int correct_predictions_top5 = 0;
337338

339+
LOG_MEM("Building master graph");
340+
341+
it_lab_ai::Shape full_shape({num_photo, static_cast<size_t>(channels),
342+
static_cast<size_t>(height),
343+
static_cast<size_t>(width)});
344+
it_lab_ai::Tensor dummy_input = make_tensor(all_image_data, full_shape);
345+
346+
it_lab_ai::Shape full_output_shape({num_photo, output_classes});
347+
it_lab_ai::Tensor dummy_output(full_output_shape, it_lab_ai::Type::kFloat);
348+
349+
Graph graph;
350+
build_graph(graph, dummy_input, dummy_output, json_path, options, false);
351+
LOG_MEM("Master graph built");
352+
353+
std::shared_ptr<Layer> input_layer = nullptr;
354+
std::shared_ptr<Layer> output_layer = nullptr;
355+
356+
for (int i = 0; i < graph.getLayersCount(); ++i) {
357+
auto layer = graph.getLayerFromID(i);
358+
if (layer->getName() == kInput) {
359+
input_layer = layer;
360+
}
361+
if (i == graph.getLayersCount() - 1) {
362+
output_layer = layer;
363+
}
364+
}
365+
366+
if (!input_layer || !output_layer) {
367+
std::cerr << "Error: Could not find input/output layers" << '\n';
368+
return 1;
369+
}
370+
338371
LOG_MEM("Starting batch processing");
339372
auto total_start_time = std::chrono::high_resolution_clock::now();
340373
int total_inference_time = 0;
374+
int batch_count = 0;
341375

342376
for (size_t batch_start = 0; batch_start < num_photo;
343377
batch_start += batch_size) {
@@ -365,32 +399,20 @@ int main(int argc, char* argv[]) {
365399
it_lab_ai::Shape batch_output_shape({current_batch_size, output_classes});
366400
it_lab_ai::Tensor batch_output(batch_output_shape, it_lab_ai::Type::kFloat);
367401

368-
Graph graph;
369-
build_graph(graph, batch_input, batch_output, json_path, options, false);
402+
graph.setInput(input_layer, batch_input);
403+
graph.setOutput(output_layer, batch_output);
370404

371405
LOG_MEM("Batch inference");
372-
// auto batch_start_time =
373-
// std::chrono::high_resolution_clock::now();
406+
auto batch_start_time = std::chrono::high_resolution_clock::now();
374407
graph.inference(options);
375-
total_inference_time += print_time_stats(graph);
376-
// auto batch_end_time = std::chrono::high_resolution_clock::now();
377-
// int batch_time =
378-
// static_cast<int>(std::chrono::duration_cast<std::chrono::milliseconds>(
379-
// batch_end_time - batch_start_time)
380-
// .count()); // ← static_cast added
381-
// total_inference_time += batch_time;
382-
// batch_count++;
383-
384-
// #ifdef ENABLE_STATISTIC_TIME
385-
// std::vector<int> elps_time = graph.getTime();
386-
// int batch_time = std::accumulate(elps_time.begin(),
387-
// elps_time.end(), 0); total_inference_time += batch_time;
388-
// batch_count++;
389-
//
390-
// char time_log[100];
391-
// sprintf(time_log, "Batch %d time: %d ms", batch_count,
392-
// batch_time); LOG_MEM(time_log);
393-
// #endif
408+
auto batch_end_time = std::chrono::high_resolution_clock::now();
409+
410+
int batch_time =
411+
static_cast<int>(std::chrono::duration_cast<std::chrono::milliseconds>(
412+
batch_end_time - batch_start_time)
413+
.count());
414+
total_inference_time += batch_time;
415+
batch_count++;
394416

395417
const std::vector<float>& raw_batch_output = *batch_output.as<float>();
396418

@@ -451,12 +473,6 @@ int main(int argc, char* argv[]) {
451473
<< (batch_count > 0 ? total_inference_time / batch_count : 0)
452474
<< " ms\n";
453475
std::cout << "!INFERENCE TIME INFO END!" << '\n';
454-
/*std::cout << "\n!INFERENCE TIME INFO START!" << '\n';
455-
std::cout << "Total inference time for all batches: " << total_inference_time
456-
<< " ms\n";
457-
std::cout << "Number of batches: " << batch_count << '\n';
458-
std::cout << "!INFERENCE TIME INFO END!" << '\n';
459-
LOG_MEM("All batches processed");*/
460476

461477
double final_accuracy_top1 =
462478
(static_cast<double>(correct_predictions_top1) / num_photo) * 100;

0 commit comments

Comments
 (0)