JeVoisBase  1.5
JeVois Smart Embedded Machine Vision Toolkit Base Modules
Share this page:
ObjectRecognitionMNIST.C
Go to the documentation of this file.
1 // ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
2 //
3 // JeVois Smart Embedded Machine Vision Toolkit - Copyright (C) 2016 by Laurent Itti, the University of Southern
4 // California (USC), and iLab at USC. See http://iLab.usc.edu and http://jevois.org for information about this project.
5 //
6 // This file is part of the JeVois Smart Embedded Machine Vision Toolkit. This program is free software; you can
7 // redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software
8 // Foundation, version 2. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
9 // without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public
10 // License for more details. You should have received a copy of the GNU General Public License along with this program;
11 // if not, write to the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
12 //
13 // Contact information: Laurent Itti - 3641 Watt Way, HNB-07A - Los Angeles, CA 90089-2520 - USA.
14 // Tel: +1 213 740 3527 - itti@pollux.usc.edu - http://iLab.usc.edu - http://jevois.org
15 // ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////
16 /*! \file */
17 
19 #include "tiny-dnn/tiny_dnn/tiny_dnn.h"
20 #include <jevois/Debug/Log.H>
21 
22 // ####################################################################################################
23 ObjectRecognitionMNIST::ObjectRecognitionMNIST(std::string const & instance) :
24  ObjectRecognition<tiny_dnn::sequential>(instance)
25 {
26  // Note: base class constructor allocates net
27 }
28 
29 // ####################################################################################################
31 {
32  // Nothing to do, base class destructor will de-allocate the network
33 }
34 
35 // ####################################################################################################
37 {
38  // LeNet for MNIST handwritten digit recognition: 32x32 in, 10 classes out:
39 #define O true
40 #define X false
41  static bool const tbl[] = {
42  O, X, X, X, O, O, O, X, X, O, O, O, O, X, O, O,
43  O, O, X, X, X, O, O, O, X, X, O, O, O, O, X, O,
44  O, O, O, X, X, X, O, O, O, X, X, O, X, O, O, O,
45  X, O, O, O, X, X, O, O, O, O, X, X, O, X, O, O,
46  X, X, O, O, O, X, X, O, O, O, O, X, O, O, X, O,
47  X, X, X, O, O, O, X, X, O, O, O, O, X, O, O, O
48  };
49 #undef O
50 #undef X
51  // by default will use backend_t::tiny_dnn unless you compiled
52  // with -DUSE_AVX=ON and your device supports AVX intrinsics
53  tiny_dnn::core::backend_t backend_type = tiny_dnn::core::default_engine();
54 
55  // Construct network:
56 // construct nets
57  //
58  // C : convolution
59  // S : sub-sampling
60  // F : fully connected
61  // clang-format off
62  using fc = tiny_dnn::layers::fc;
63  using conv = tiny_dnn::layers::conv;
64  using ave_pool = tiny_dnn::layers::ave_pool;
65  using tanh = tiny_dnn::activation::tanh;
66 
67  using tiny_dnn::core::connection_table;
68  using padding = tiny_dnn::padding;
69 
70  (*net) << conv(32, 32, 5, 1, 6, padding::valid, true, 1, 1, backend_type) // C1, 1@32x32-in, 6@28x28-out
71  << tanh()
72  << ave_pool(28, 28, 6, 2) // S2, 6@28x28-in, 6@14x14-out
73  << tanh()
74  << conv(14, 14, 5, 6, 16, connection_table(tbl, 6, 16),
75  padding::valid, true, 1, 1, backend_type) // C3, 6@14x14-in, 16@10x10-out
76  << tanh()
77  << ave_pool(10, 10, 16, 2) // S4, 16@10x10-in, 16@5x5-out
78  << tanh()
79  << conv(5, 5, 5, 16, 120, padding::valid, true, 1, 1, backend_type) // C5, 16@5x5-in, 120@1x1-out
80  << tanh()
81  << fc(120, 10, true, backend_type) // F6, 120-in, 10-out
82  << tanh();
83 }
84 
85 // ####################################################################################################
86 void ObjectRecognitionMNIST::train(std::string const & path)
87 {
88  LINFO("Load training data from directory " << path);
89 
90  // Load MNIST dataset:
91  std::vector<tiny_dnn::label_t> train_labels, test_labels;
92  std::vector<tiny_dnn::vec_t> train_images, test_images;
93  LINFO("Load training labels...");
94  tiny_dnn::parse_mnist_labels(std::string(path) + "/train-labels.idx1-ubyte", &train_labels);
95  LINFO("Load training images...");
96  tiny_dnn::parse_mnist_images(std::string(path) + "/train-images.idx3-ubyte", &train_images, -1.0, 1.0, 2, 2);
97  LINFO("Load test labels...");
98  tiny_dnn::parse_mnist_labels(std::string(path) + "/t10k-labels.idx1-ubyte", &test_labels);
99  LINFO("Load test images...");
100  tiny_dnn::parse_mnist_images(std::string(path) + "/t10k-images.idx3-ubyte", &test_images, -1.0, 1.0, 2, 2);
101 
102  LINFO("Start training...");
103  int minibatch_size = 10;
104  int num_epochs = 30;
105  tiny_dnn::timer t;
106 
107  // Create callbacks:
108  auto on_enumerate_epoch = [&](){
109  LINFO(t.elapsed() << "s elapsed.");
110  tiny_dnn::result res = net->test(test_images, test_labels);
111  LINFO(res.num_success << "/" << res.num_total << " success/total validation score so far");
112  t.restart();
113  };
114 
115  auto on_enumerate_minibatch = [&](){ };
116 
117  // Training:
118  tiny_dnn::adagrad optimizer;
119  optimizer.alpha *= static_cast<tiny_dnn::float_t>(std::sqrt(minibatch_size));
120 
121  net->train<tiny_dnn::mse>(optimizer, train_images, train_labels, minibatch_size, num_epochs,
122  on_enumerate_minibatch, on_enumerate_epoch);
123 
124  LINFO("Training complete");
125 
126  // Test and show results:
127  net->test(test_images, test_labels).print_detail(std::cout);
128 }
129 
130 // ####################################################################################################
131 std::string const & ObjectRecognitionMNIST::category(size_t idx) const
132 {
133  static std::vector<std::string> const names = { "0", "1", "2", "3", "4", "5", "6", "7", "8", "9" };
134 
135  if (idx >= names.size()) LFATAL("Category index out of bounds");
136 
137  return names[idx];
138 }
virtual void define() override
Define the network structure.
virtual ~ObjectRecognitionMNIST()
Destructor.
ObjectRecognitionMNIST(std::string const &instance)
Constructor; loads the given CNN, whose sizes must match our (fixed) internal network structure...
virtual std::string const & category(size_t idx) const override
Return the name of a given category (0-based index in the vector of results)
#define X
#define LFATAL(msg)
virtual void train(std::string const &path) override
Train the network.
Wrapper around a neural network implemented with the tiny-dnn framework by Taiga Nomi...
#define O
tiny_dnn::network< tiny_dnn::sequential > * net
#define LINFO(msg)