Skip to content

Commit e70da91

Browse files
committed
Restructure the code, add vit.h header & main entry
1 parent 708be7e commit e70da91

3 files changed

Lines changed: 296 additions & 246 deletions

File tree

main.cpp

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
#define _CRT_SECURE_NO_DEPRECATE // disables "unsafe" warnings on Windows
2+
3+
#include "vit.h"
4+
#include "ggml/ggml.h"
5+
#include "ggml/ggml-alloc.h"
6+
#include "ggml/examples/stb_image.h" // stb image load
7+
8+
#include <cassert>
9+
#include <cmath>
10+
#include <cstddef>
11+
#include <cstdio>
12+
#include <cstring>
13+
#include <fstream>
14+
#include <map>
15+
#include <string>
16+
#include <vector>
17+
#include <thread>
18+
#include <cinttypes>
19+
#include <algorithm>
20+
21+
#if defined(_MSC_VER)
22+
#pragma warning(disable : 4244 4267) // possible loss of data
23+
#endif
24+
25+
// main function
26+
int main(int argc, char **argv)
27+
{
28+
const int64_t t_main_start_us = ggml_time_us();
29+
30+
vit_params params;
31+
32+
image_u8 img0;
33+
image_f32 img1;
34+
35+
vit_model model;
36+
vit_state state;
37+
std::vector<std::pair<float, int>> predictions;
38+
39+
int64_t t_load_us = 0;
40+
41+
if (vit_params_parse(argc, argv, params) == false)
42+
{
43+
return 1;
44+
}
45+
46+
if (params.seed < 0)
47+
{
48+
params.seed = time(NULL);
49+
}
50+
fprintf(stderr, "%s: seed = %d\n", __func__, params.seed);
51+
fprintf(stderr, "%s: n_threads = %d / %d\n", __func__, params.n_threads, (int32_t)std::thread::hardware_concurrency());
52+
53+
// load the model
54+
{
55+
const int64_t t_start_us = ggml_time_us();
56+
57+
if (!vit_model_load(params.model.c_str(), model))
58+
{
59+
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
60+
return 1;
61+
}
62+
63+
t_load_us = ggml_time_us() - t_start_us;
64+
}
65+
66+
// load the image
67+
if (!load_image_from_file(params.fname_inp.c_str(), img0))
68+
{
69+
fprintf(stderr, "%s: failed to load image from '%s'\n", __func__, params.fname_inp.c_str());
70+
return 1;
71+
}
72+
fprintf(stderr, "%s: loaded image '%s' (%d x %d)\n", __func__, params.fname_inp.c_str(), img0.nx, img0.ny);
73+
74+
// preprocess the image to f32
75+
if (vit_image_preprocess(img0, img1, model.hparams))
76+
{
77+
fprintf(stderr, "processed, out dims : (%d x %d)\n", img1.nx, img1.ny);
78+
}
79+
80+
// prepare for graph computation, memory allocation and results processing
81+
{
82+
static size_t buf_size = 3u * 1024 * 1024;
83+
84+
struct ggml_init_params ggml_params = {
85+
/*.mem_size =*/buf_size,
86+
/*.mem_buffer =*/NULL,
87+
/*.no_alloc =*/false,
88+
};
89+
90+
state.ctx = ggml_init(ggml_params);
91+
state.prediction = ggml_new_tensor_4d(state.ctx, GGML_TYPE_F32, model.hparams.num_classes, 1, 1, 1);
92+
93+
// printf("%s: Initialized context = %ld bytes\n", __func__, buf_size);
94+
}
95+
96+
{
97+
// run prediction on img1
98+
vit_predict(model, state, img1, params, predictions);
99+
}
100+
101+
// report timing
102+
{
103+
const int64_t t_main_end_us = ggml_time_us();
104+
fprintf(stderr, "\n\n");
105+
fprintf(stderr, "%s: model load time = %8.2f ms\n", __func__, t_load_us / 1000.0f);
106+
fprintf(stderr, "%s: processing time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us - t_load_us) / 1000.0f);
107+
fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us) / 1000.0f);
108+
}
109+
110+
ggml_free(model.ctx);
111+
112+
return 0;
113+
}

0 commit comments

Comments
 (0)