-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathBitNetworkDriver.cc
More file actions
70 lines (59 loc) · 2.5 KB
/
BitNetworkDriver.cc
File metadata and controls
70 lines (59 loc) · 2.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#include "BitNetworkDriver.h"
// Builds the driver's network with three layers of widths 784, 16 and 10.
// The final width (10) matches the 10-row target vector used in trainNetwork.
// NOTE(review): layerSizes is a local array; if BitNetwork keeps the pointer
// rather than copying the widths, it would dangle after this constructor — confirm.
BitNetworkDriver::BitNetworkDriver() {
    int layerSizes[] = {784, 16, 10};
    const int layerCount = sizeof(layerSizes) / sizeof(layerSizes[0]);
    bitNetwork = BitNetwork(layerSizes, layerCount);
}
void BitNetworkDriver::run() {
cout << "Starting Bit Network Driver\n";
trainNetwork();
checkNetwork();
}
// Trains bitNetwork with one backProp call per image of the MNIST training set.
// Each image is normalized with Utility::normalizeImageToTernary and converted
// to an int matrix. It is paired with a kNumClasses x 1 target vector that is
// -1 everywhere except +1 at the image's label. Progress is logged every 10000
// images, and bitNetwork.debug() is called once the pass is finished.
void BitNetworkDriver::trainNetwork() {
    char trainingImagesFilename[] = "./train-images.idx3-ubyte";
    char trainingLabelsFilename[] = "./train-labels.idx1-ubyte";
    MNISTHelper mnistTrainingSet = MNISTHelper(trainingImagesFilename, trainingLabelsFilename);
    // Number of output classes; must match the last layer width given to
    // BitNetwork in the constructor (10).
    const int kNumClasses = 10;
    cout << "~~~~~~~~~~~~~~~~ Start training ~~~~~~~~~~~~~~~ \n";
    for (int i = 0; i < mnistTrainingSet.number_of_images; ++i) {
        if (i % 10000 == 0) {
            cout << "Trained with " << setw(5) << i << "/" << mnistTrainingSet.number_of_images << " images\n";
        }
        // Always consume the next sample so image/label order stays in sync,
        // even when the sample is skipped below.
        MNISTImage mnistImage = mnistTrainingSet.getNext();
        const int label = static_cast<int>(mnistImage.label);
        // A label outside [0, kNumClasses) would index past the end of
        // `expected`; skip such samples instead of writing out of bounds.
        if (label < 0 || label >= kNumClasses) {
            cout << "Skipping image " << i << " with invalid label " << label << "\n";
            continue;
        }
        mnistImage.image.apply(Utility::normalizeImageToTernary);
        Matrix<int> expected(kNumClasses, 1);
        expected.setTo(-1);
        expected[label][0] = 1;
        Matrix<int> image = Utility::matrixConverter<double, int>(mnistImage.image);
        bitNetwork.backProp(image, expected);
    }
    bitNetwork.debug();
    cout << "~~~~~~~~~~~~~~ Finished training ~~~~~~~~~~~~~~~ \n\n";
}
// Evaluates bitNetwork on the MNIST test set and reports its accuracy.
// For each test image, the predicted digit is the index of the largest entry
// in the forProp output column; on ties, the first index wins. The raw output
// and the expected-vs-predicted line are printed only for the first
// kVerboseImages images to keep the log readable. The accuracy is computed
// over every image actually checked.
void BitNetworkDriver::checkNetwork()
{
    char testingImagesFilename[] = "./t10k-images.idx3-ubyte";
    char testingLabelsFilename[] = "./t10k-labels.idx1-ubyte";
    MNISTHelper mnistTestingSet = MNISTHelper(testingImagesFilename, testingLabelsFilename);
    // Number of leading test images whose per-image details are printed.
    const int kVerboseImages = 2;
    cout << "~~~~~~~~~~~~~~~~~~ Start Check ~~~~~~~~~~~~~~~~~ \n";
    int amtCorrect = 0;
    int amtChecked = 0;
    for (int i = 0; i < mnistTestingSet.number_of_images; ++i)
    {
        MNISTImage mnistImage = mnistTestingSet.getNext();
        mnistImage.image.apply(Utility::normalizeImageToTernary);
        Matrix<int> image = Utility::matrixConverter<double, int>(mnistImage.image);
        Matrix<int> out = bitNetwork.forProp(image);
        // Argmax over the output rows. It is seeded with row 0 rather than a
        // fixed 0 threshold, so outputs that are all negative (training
        // targets are -1/+1) still select the true maximum.
        int outputNumber = 0;
        for (long unsigned int row = 1; row < out.size(); ++row) {
            if (out[row][0] > out[outputNumber][0]) {
                outputNumber = static_cast<int>(row);
            }
        }
        const int expectedNumber = static_cast<int>(mnistImage.label);
        if (i < kVerboseImages) {
            out.print();
            cout << "Expected: " << expectedNumber << " Got: " << outputNumber << "\n";
        }
        if (expectedNumber == outputNumber) { amtCorrect++; }
        amtChecked++;
    }
    cout << "Got " << amtCorrect << " out of " << amtChecked << " correct\n";
    // Guard against an empty test set, which would otherwise divide by zero.
    if (amtChecked > 0) {
        cout << "Or " << (static_cast<double>(amtCorrect) / static_cast<double>(amtChecked)) * 100 << "%\n";
    }
    cout << "~~~~~~~~~~~~~~~~~~~ End Check ~~~~~~~~~~~~~~~~~~ \n";
}