1+ #define _CRT_SECURE_NO_DEPRECATE // disables "unsafe" warnings on Windows
2+
#include "vit.h"
#include "ggml/ggml.h"
#include "ggml/ggml-alloc.h"
#include "ggml/examples/stb_image.h" // stb image load

#include <algorithm>
#include <cassert>
#include <cinttypes>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <fstream>
#include <map>
#include <string>
#include <thread>
#include <utility>
#include <vector>
20+
21+ #if defined(_MSC_VER)
22+ #pragma warning(disable : 4244 4267) // possible loss of data
23+ #endif
24+
25+ // main function
26+ int main (int argc, char **argv)
27+ {
28+ const int64_t t_main_start_us = ggml_time_us ();
29+
30+ vit_params params;
31+
32+ image_u8 img0;
33+ image_f32 img1;
34+
35+ vit_model model;
36+ vit_state state;
37+ std::vector<std::pair<float , int >> predictions;
38+
39+ int64_t t_load_us = 0 ;
40+
41+ if (vit_params_parse (argc, argv, params) == false )
42+ {
43+ return 1 ;
44+ }
45+
46+ if (params.seed < 0 )
47+ {
48+ params.seed = time (NULL );
49+ }
50+ fprintf (stderr, " %s: seed = %d\n " , __func__, params.seed );
51+ fprintf (stderr, " %s: n_threads = %d / %d\n " , __func__, params.n_threads , (int32_t )std::thread::hardware_concurrency ());
52+
53+ // load the model
54+ {
55+ const int64_t t_start_us = ggml_time_us ();
56+
57+ if (!vit_model_load (params.model .c_str (), model))
58+ {
59+ fprintf (stderr, " %s: failed to load model from '%s'\n " , __func__, params.model .c_str ());
60+ return 1 ;
61+ }
62+
63+ t_load_us = ggml_time_us () - t_start_us;
64+ }
65+
66+ // load the image
67+ if (!load_image_from_file (params.fname_inp .c_str (), img0))
68+ {
69+ fprintf (stderr, " %s: failed to load image from '%s'\n " , __func__, params.fname_inp .c_str ());
70+ return 1 ;
71+ }
72+ fprintf (stderr, " %s: loaded image '%s' (%d x %d)\n " , __func__, params.fname_inp .c_str (), img0.nx , img0.ny );
73+
74+ // preprocess the image to f32
75+ if (vit_image_preprocess (img0, img1, model.hparams ))
76+ {
77+ fprintf (stderr, " processed, out dims : (%d x %d)\n " , img1.nx , img1.ny );
78+ }
79+
80+ // prepare for graph computation, memory allocation and results processing
81+ {
82+ static size_t buf_size = 3u * 1024 * 1024 ;
83+
84+ struct ggml_init_params ggml_params = {
85+ /* .mem_size =*/ buf_size,
86+ /* .mem_buffer =*/ NULL ,
87+ /* .no_alloc =*/ false ,
88+ };
89+
90+ state.ctx = ggml_init (ggml_params);
91+ state.prediction = ggml_new_tensor_4d (state.ctx , GGML_TYPE_F32, model.hparams .num_classes , 1 , 1 , 1 );
92+
93+ // printf("%s: Initialized context = %ld bytes\n", __func__, buf_size);
94+ }
95+
96+ {
97+ // run prediction on img1
98+ vit_predict (model, state, img1, params, predictions);
99+ }
100+
101+ // report timing
102+ {
103+ const int64_t t_main_end_us = ggml_time_us ();
104+ fprintf (stderr, " \n\n " );
105+ fprintf (stderr, " %s: model load time = %8.2f ms\n " , __func__, t_load_us / 1000 .0f );
106+ fprintf (stderr, " %s: processing time = %8.2f ms\n " , __func__, (t_main_end_us - t_main_start_us - t_load_us) / 1000 .0f );
107+ fprintf (stderr, " %s: total time = %8.2f ms\n " , __func__, (t_main_end_us - t_main_start_us) / 1000 .0f );
108+ }
109+
110+ ggml_free (model.ctx );
111+
112+ return 0 ;
113+ }
0 commit comments