-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmnist-utils.cc
More file actions
227 lines (168 loc) · 6.61 KB
/
mnist-utils.cc
File metadata and controls
227 lines (168 loc) · 6.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
//========================================================================
// mnist-utils.cc
//========================================================================
// Utility functions for MNSIT dataset.
//
// Author: Yanghui Ou
// Date: Nov 6, 2019
#include <fstream>
#include <iostream>
#include "mnist-utils.h"
#include "IHandwritingRecSys.h"
#include "Vector.h"
#include "Image.h"
//------------------------------------------------------------------------
// constants
//------------------------------------------------------------------------
const int mnist_ncols = 28;
const int mnist_nrows = 28;
const int mnist_size = mnist_nrows * mnist_ncols;
const char* cursor_u = "\033[A"; // Move cursor up one line
const char* cursor_d = "\033[B"; // Move cursor up one line
const char* cursor_e = "\033[K"; // Erase to end of the line
//------------------------------------------------------------------------
// read_labeled_images
//------------------------------------------------------------------------
// Reads the images and labels file and fills a Vector<Image> with the
// corresponding labeled images.
void read_labeled_images( const std::string& images_path,
const std::string& labels_path,
Vector<Image>& vec,
size_t size )
{
std::ifstream myifs;
// MNIST has 4 misc values in image bin and 2 in label bin
const int n_misc_in_image_file = 4;
const int n_misc_in_label_file = 2;
//----------------------------------------------------------------------
// Read images
//----------------------------------------------------------------------
Image* labeled_images = new Image[size];
int* data = new int[mnist_size];
// Open binary file
myifs.open( images_path, std::ios::in | std::ios::binary );
// Skip the miscellaneous metadata (4B each)
int misc = 0;
for ( int i = 0; i < n_misc_in_image_file; i++ )
myifs.read( (char*)&misc, 4 );
// Read each image (28 x 28 bytes) and add to the Image
for ( size_t idx = 0; idx < size; idx++ ) {
for ( int i = 0; i < mnist_size; i++ ) {
int tmp = 0;
myifs.read( (char*)&tmp, 1 );
data[i] = tmp;
}
// Add this image to the array
labeled_images[idx] = Image( Vector<int>( data, mnist_size ),
mnist_ncols, mnist_nrows );
}
// Close file
myifs.close();
//----------------------------------------------------------------------
// Read images
//----------------------------------------------------------------------
// Open binary file
myifs.open( labels_path, std::ios::in | std::ios::binary );
// Skip the miscellaneous metadata (4B each)
for ( int i = 0; i < n_misc_in_label_file; i++ )
myifs.read( (char*)&misc, 4 );
// Read each label and label the corresponding image
for ( size_t idx = 0; idx < size; idx++ ) {
// Read the label
int tmp = 0;
myifs.read( (char*)&tmp, 1 );
char label_char = '0' + tmp;
// Set label for the corresponding image
labeled_images[idx].set_label( label_char );
}
//----------------------------------------------------------------------
// Copy to the image vector
//----------------------------------------------------------------------
vec = Vector<Image>(); // Clear the data
for ( size_t i = 0; i < size; i++ )
vec.push_back( labeled_images[i] );
// Clear dynamic memory
delete[] data;
delete[] labeled_images;
}
//------------------------------------------------------------------------
// train_and_classify
//------------------------------------------------------------------------
// Takes a handwriting recognition system, trains it with all of the
// labeled images in the training set, classifies all of the images in the
// inference set, and then checks how many classifications are correct,
// returning the accuracy as a double.
double train_and_classify( IHandwritingRecSys& hrs,
Vector<Image>& v_train,
Vector<Image>& v_test )
{
// Return 0 if testing set is empty
if ( v_test.size() == 0 )
return 0;
// Train the classifier
hrs.train( v_train );
// Classify the images in the testing set
size_t correct = 0;
size_t total = v_test.size();
for ( size_t i = 0; i < total; i++ ) {
// Check predicted results
if ( hrs.classify( v_test[i] ).get_label() == v_test[i].get_label() )
correct++;
// Prints out the number of correct predictions per 1000 images
if ( ( i + 1 ) % 1000 == 0 ) {
std::cout << "#images tested : " << i + 1 << std::endl;
std::cout << " - correctly predicted : " << i + 1 << std::endl;
}
}
// Calculate accuracy
return (double)correct / (double)total;
}
//------------------------------------------------------------------------
// classify_with_progress_bar
//------------------------------------------------------------------------
// Takes a handwriting recognition system, runs classfication on the given
// testing set and prints a progress bar, and returns the accuracy.
double classify_with_progress_bar( IHandwritingRecSys& hrs,
Vector<Image>& v_test )
{
// Return 0 if testing set is empty to avoid devide by 0
const size_t test_size = v_test.size();
if ( test_size == 0 )
return 0;
// Set up progress bar for inference
size_t n_frac = 70;
size_t frac_size = ( test_size > n_frac ) ? test_size / n_frac : 1;
size_t len_prog_bar = test_size / frac_size;
// Run inference
size_t num_correct = 0;
for ( size_t i = 0; i < test_size; i++ ) {
// Check result
char predicted = hrs.classify( v_test[i] ).get_label();
if ( predicted == v_test[i].get_label() )
num_correct++;
// Progress bar
if ( i % frac_size == 0 ) {
size_t n_markers = i / frac_size;
std::cout << "[ ";
for ( size_t j = 1; j < n_markers; j++ )
std::cout << "=";
std::cout << ">";
for ( size_t j = 0; j < len_prog_bar - n_markers; j++ )
std::cout << ".";
std::cout << " ]";
std::cout << cursor_e;
std::cout << std::endl;
std::cout << " - classifying image " << i << " of " << test_size;
std::cout << cursor_e;
std::cout << std::endl;
std::cout << " - #correct classifications : " << num_correct;
std::cout << cursor_e;
std::cout << std::endl;
std::cout << cursor_u << cursor_u << cursor_u;
}
}
// Delete output and reset cursor
std::cout << cursor_e << cursor_d << cursor_e << cursor_d
<< cursor_e << cursor_u << cursor_u;
return (double)num_correct / (double)test_size;
}