// Sparse connectivity mask for the second convolution below, which receives
// it as connection_table(tbl, 6, 16): 6 x 16 = 96 entries.
// Each embedded source line 42..47 starts a group of 16 entries, i.e. one row.
// NOTE(review): presumably row-major with rows = the 6 input channels and
// columns = the 16 output channels; confirm against
// tiny_dnn::core::connection_table.
// NOTE(review): O and X are not defined in this view. They are presumably
// macros for true (connected) and false (not connected) defined just above
// and #undef'd afterwards; confirm.
// NOTE(review): the pattern appears to match Table 1 of LeCun et al. (1998)
// (S2 -> C3 connections); confirm.
// NOTE(review): this listing omits the closing `};` (embedded lines 48-52
// are not shown).
41 static bool const tbl[] = {
// row 0
42 O,
X,
X,
X,
O,
O,
O,
X,
X,
O,
O,
O,
O,
X,
O,
O,
// row 1
43 O,
O,
X,
X,
X,
O,
O,
O,
X,
X,
O,
O,
O,
O,
X,
O,
// row 2
44 O,
O,
O,
X,
X,
X,
O,
O,
O,
X,
X,
O,
X,
O,
O,
O,
// row 3
45 X,
O,
O,
O,
X,
X,
O,
O,
O,
O,
X,
X,
O,
X,
O,
O,
// row 4
46 X,
X,
O,
O,
O,
X,
X,
O,
O,
O,
O,
X,
O,
O,
X,
O,
// row 5
47 X,
X,
X,
O,
O,
O,
X,
X,
O,
O,
O,
O,
X,
O,
O,
O
// Network construction: every conv/fc layer below is created with the
// engine returned by tiny_dnn::core::default_engine().
53 tiny_dnn::core::backend_t backend_type = tiny_dnn::core::default_engine();
// Short local aliases for the tiny_dnn layer, activation and padding types.
62 using fc = tiny_dnn::layers::fc;
63 using conv = tiny_dnn::layers::conv;
64 using ave_pool = tiny_dnn::layers::ave_pool;
65 using tanh = tiny_dnn::activation::tanh;
67 using tiny_dnn::core::connection_table;
68 using padding = tiny_dnn::padding;
// Layer stack. The conv arguments read as (in_width, in_height, window_size,
// in_channels, out_channels, ...), inferred from how the sizes of
// consecutive layers chain together. The spatial sizes follow from the next
// layer's input arguments.
// NOTE(review): the trailing `true, 1, 1` are presumably has_bias, w_stride,
// h_stride; confirm against tiny_dnn::layers::conv.
// NOTE(review): `tanh` is aliased above but unused in the visible lines. The
// embedded numbering skips 71, 73, 76, 78 and 80, presumably where the
// activation layers go; confirm.
// NOTE(review): `net` is declared outside this view, and the chain's
// terminating `;` lies in omitted lines.
// conv 5x5, 1 -> 6 channels, valid: 32x32 -> 28x28
70 (*net) << conv(32, 32, 5, 1, 6, padding::valid,
true, 1, 1, backend_type)
// average pooling over 6 channels: 28x28 -> 14x14
72 << ave_pool(28, 28, 6, 2)
// conv 5x5, 6 -> 16 channels, sparse via tbl, valid: 14x14 -> 10x10
74 << conv(14, 14, 5, 6, 16, connection_table(tbl, 6, 16),
75 padding::valid,
true, 1, 1, backend_type)
// average pooling over 16 channels: 10x10 -> 5x5
77 << ave_pool(10, 10, 16, 2)
// conv 5x5, 16 -> 120 channels, valid: 5x5 -> 1x1
79 << conv(5, 5, 5, 16, 120, padding::valid,
true, 1, 1, backend_type)
// fully connected 120 -> 10 outputs
81 << fc(120, 10,
true, backend_type)
// Load the MNIST training and test sets from the directory `path`, which is
// declared outside this view. Each step is logged with LINFO.
// NOTE(review): no handling of missing or corrupt files is visible here, and
// whether parse_mnist_* throws is not shown.
88 LINFO(
"Load training data from directory " << path);
91 std::vector<tiny_dnn::label_t> train_labels, test_labels;
92 std::vector<tiny_dnn::vec_t> train_images, test_images;
93 LINFO(
"Load training labels...");
94 tiny_dnn::parse_mnist_labels(std::string(path) +
"/train-labels.idx1-ubyte", &train_labels);
95 LINFO(
"Load training images...");
// NOTE(review): -1.0, 1.0, 2, 2 presumably rescale pixels to [-1, 1] and
// pad 2 pixels per side, which would turn 28x28 digits into the 32x32 input
// of the first conv layer; confirm against parse_mnist_images.
96 tiny_dnn::parse_mnist_images(std::string(path) +
"/train-images.idx3-ubyte", &train_images, -1.0, 1.0, 2, 2);
97 LINFO(
"Load test labels...");
98 tiny_dnn::parse_mnist_labels(std::string(path) +
"/t10k-labels.idx1-ubyte", &test_labels);
99 LINFO(
"Load test images...");
// Same scaling and padding as the training images.
100 tiny_dnn::parse_mnist_images(std::string(path) +
"/t10k-images.idx3-ubyte", &test_images, -1.0, 1.0, 2, 2);
// Train `net` with AdaGrad and the tiny_dnn::mse loss on the loaded training
// set, then print a detailed evaluation on the test set.
102 LINFO(
"Start training...");
103 int minibatch_size = 10;
// Callback passed to train() in the on_enumerate_epoch slot: logs elapsed
// time, then evaluates `net` on the test set and logs successes/total.
// NOTE(review): `t` is declared in omitted embedded lines 104-107
// (presumably a timer). The lambda's closing and any timer restart lie in
// omitted lines 112-114.
// NOTE(review): the log text says "validation score", but the score is
// computed on test_images/test_labels.
108 auto on_enumerate_epoch = [&](){
109 LINFO(t.elapsed() <<
"s elapsed.");
110 tiny_dnn::result res =
net->test(test_images, test_labels);
111 LINFO(res.num_success <<
"/" << res.num_total <<
" success/total validation score so far");
// Callback passed in the on_enumerate_minibatch slot; its body is empty.
115 auto on_enumerate_minibatch = [&](){ };
// Scale the AdaGrad learning rate by sqrt(minibatch_size). std::sqrt is
// called on an int here, so the result is double before the cast.
// NOTE(review): presumably compensates for per-minibatch gradient
// averaging; confirm against tiny_dnn's train().
118 tiny_dnn::adagrad optimizer;
119 optimizer.alpha *=
static_cast<tiny_dnn::float_t
>(std::sqrt(minibatch_size));
// `num_epochs` is declared outside this view.
121 net->train<tiny_dnn::mse>(optimizer, train_images, train_labels, minibatch_size, num_epochs,
122 on_enumerate_minibatch, on_enumerate_epoch);
124 LINFO(
"Training complete");
// Final evaluation: print the detailed test-set result to stdout.
127 net->test(test_images, test_labels).print_detail(std::cout);