JeVoisBase  1.21
JeVois Smart Embedded Machine Vision Toolkit Base Modules
Share this page:
Loading...
Searching...
No Matches
ObjectRecognitionMNIST.C
Go to the documentation of this file.
1// ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
2//
3// JeVois Smart Embedded Machine Vision Toolkit - Copyright (C) 2016 by Laurent Itti, the University of Southern
4// California (USC), and iLab at USC. See http://iLab.usc.edu and http://jevois.org for information about this project.
5//
6// This file is part of the JeVois Smart Embedded Machine Vision Toolkit. This program is free software; you can
7// redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software
8// Foundation, version 2. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
9// without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public
10// License for more details. You should have received a copy of the GNU General Public License along with this program;
11// if not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
12//
13// Contact information: Laurent Itti - 3641 Watt Way, HNB-07A - Los Angeles, CA 90089-2520 - USA.
14// Tel: +1 213 740 3527 - itti@pollux.usc.edu - http://iLab.usc.edu - http://jevois.org
15// ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
16/*! \file */
17
19#include "tiny-dnn/tiny_dnn/tiny_dnn.h"
20#include <jevois/Debug/Log.H>
21
22// ####################################################################################################
24 ObjectRecognition<tiny_dnn::sequential>(instance)
25{
26 // Note: base class constructor allocates net
27}
28
29// ####################################################################################################
31{
32 // Nothing to do, base class destructor will de-allocate the network
33}
34
35// ####################################################################################################
37{
38 // LeNet for MNIST handwritten digit recognition: 32x32 in, 10 classes out:
39#define O true
40#define X false
41 static bool const tbl[] = {
42 O, X, X, X, O, O, O, X, X, O, O, O, O, X, O, O,
43 O, O, X, X, X, O, O, O, X, X, O, O, O, O, X, O,
44 O, O, O, X, X, X, O, O, O, X, X, O, X, O, O, O,
45 X, O, O, O, X, X, O, O, O, O, X, X, O, X, O, O,
46 X, X, O, O, O, X, X, O, O, O, O, X, O, O, X, O,
47 X, X, X, O, O, O, X, X, O, O, O, O, X, O, O, O
48 };
49#undef O
50#undef X
51 // by default will use backend_t::tiny_dnn unless you compiled
52 // with -DUSE_AVX=ON and your device supports AVX intrinsics
53 tiny_dnn::core::backend_t backend_type = tiny_dnn::core::default_engine();
54
55 // Construct network:
56// construct nets
57 //
58 // C : convolution
59 // S : sub-sampling
60 // F : fully connected
61 // clang-format off
62 using fc = tiny_dnn::layers::fc;
63 using conv = tiny_dnn::layers::conv;
64 using ave_pool = tiny_dnn::layers::ave_pool;
65 using tanh = tiny_dnn::activation::tanh;
66
67 using tiny_dnn::core::connection_table;
68 using padding = tiny_dnn::padding;
69
70 (*net) << conv(32, 32, 5, 1, 6, padding::valid, true, 1, 1, backend_type) // C1, 1@32x32-in, 6@28x28-out
71 << tanh()
72 << ave_pool(28, 28, 6, 2) // S2, 6@28x28-in, 6@14x14-out
73 << tanh()
74 << conv(14, 14, 5, 6, 16, connection_table(tbl, 6, 16),
75 padding::valid, true, 1, 1, backend_type) // C3, 6@14x14-in, 16@10x10-out
76 << tanh()
77 << ave_pool(10, 10, 16, 2) // S4, 16@10x10-in, 16@5x5-out
78 << tanh()
79 << conv(5, 5, 5, 16, 120, padding::valid, true, 1, 1, backend_type) // C5, 16@5x5-in, 120@1x1-out
80 << tanh()
81 << fc(120, 10, true, backend_type) // F6, 120-in, 10-out
82 << tanh();
83}
84
85// ####################################################################################################
86void ObjectRecognitionMNIST::train(std::string const & path)
87{
88 LINFO("Load training data from directory " << path);
89
90 // Load MNIST dataset:
91 std::vector<tiny_dnn::label_t> train_labels, test_labels;
92 std::vector<tiny_dnn::vec_t> train_images, test_images;
93 LINFO("Load training labels...");
94 tiny_dnn::parse_mnist_labels(std::string(path) + "/train-labels.idx1-ubyte", &train_labels);
95 LINFO("Load training images...");
96 tiny_dnn::parse_mnist_images(std::string(path) + "/train-images.idx3-ubyte", &train_images, -1.0, 1.0, 2, 2);
97 LINFO("Load test labels...");
98 tiny_dnn::parse_mnist_labels(std::string(path) + "/t10k-labels.idx1-ubyte", &test_labels);
99 LINFO("Load test images...");
100 tiny_dnn::parse_mnist_images(std::string(path) + "/t10k-images.idx3-ubyte", &test_images, -1.0, 1.0, 2, 2);
101
102 LINFO("Start training...");
103 int minibatch_size = 10;
104 int num_epochs = 30;
105 tiny_dnn::timer t;
106
107 // Create callbacks:
108 auto on_enumerate_epoch = [&](){
109 LINFO(t.elapsed() << "s elapsed.");
110 tiny_dnn::result res = net->test(test_images, test_labels);
111 LINFO(res.num_success << "/" << res.num_total << " success/total validation score so far");
112 t.restart();
113 };
114
115 auto on_enumerate_minibatch = [&](){ };
116
117 // Training:
118 tiny_dnn::adagrad optimizer;
119 optimizer.alpha *= static_cast<tiny_dnn::float_t>(std::sqrt(minibatch_size));
120
121 net->train<tiny_dnn::mse>(optimizer, train_images, train_labels, minibatch_size, num_epochs,
122 on_enumerate_minibatch, on_enumerate_epoch);
123
124 LINFO("Training complete");
125
126 // Test and show results:
127 net->test(test_images, test_labels).print_detail(std::cout);
128}
129
130// ####################################################################################################
131std::string const & ObjectRecognitionMNIST::category(size_t idx) const
132{
133 static std::vector<std::string> const names = { "0", "1", "2", "3", "4", "5", "6", "7", "8", "9" };
134
135 if (idx >= names.size()) LFATAL("Category index out of bounds");
136
137 return names[idx];
138}
#define X
#define O
virtual std::string const & category(size_t idx) const override
Return the name of a given category (0-based index in the vector of results)
virtual void define() override
Define the network structure.
virtual ~ObjectRecognitionMNIST()
Destructor.
ObjectRecognitionMNIST(std::string const &instance)
Constructor; loads the given CNN, whose sizes must match our (fixed) internal network structure.
virtual void train(std::string const &path) override
Train the network.
Wrapper around a neural network implemented with the tiny-dnn framework by Taiga Nomi.
tiny_dnn::network< tiny_dnn::sequential > * net
#define LFATAL(msg)
#define LINFO(msg)