digits_lenet.cpp
6.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
// This example provides a digital recognition based on LeNet-5 and connected component analysis.
// It makes it possible for OpenCV beginner to run dnn models in real time using only CPU.
// It can read pictures from the camera in real time to make predictions, and display the recognized digits as overlays on top of the original digits.
//
// In order to achieve a better display effect, please write the number on white paper and occupy the entire camera.
//
// You can follow the following guide to train LeNet-5 by yourself using the MNIST dataset.
// https://github.com/intel/caffe/blob/a3d5b022fe026e9092fc7abc7654b1162ab9940d/examples/mnist/readme.md
//
// You can also download already trained model directly.
// https://github.com/zihaomu/opencv_digit_text_recognition_demo/tree/master/src
#include <opencv2/imgproc.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/dnn.hpp>
#include <iostream>
#include <vector>
using namespace cv;
using namespace cv::dnn;
const char *keys =
"{ help h | | Print help message. }"
"{ input i | | Path to input image or video file. Skip this argument to capture frames from a camera.}"
"{ device | 0 | camera device number. }"
"{ modelBin | | Path to a binary .caffemodel file contains trained network.}"
"{ modelTxt | | Path to a .prototxt file contains the model definition of trained network.}"
"{ width | 640 | Set the width of the camera }"
"{ height | 480 | Set the height of the camera }"
"{ thr | 0.7 | Confidence threshold. }";
// Find best class for the blob (i.e. class with maximal probability)
static void getMaxClass(const Mat &probBlob, int &classId, double &classProb);
void predictor(Net net, const Mat &roi, int &class_id, double &probability);
int main(int argc, char **argv)
{
// Parse command line arguments.
CommandLineParser parser(argc, argv, keys);
if (argc == 1 || parser.has("help"))
{
parser.printMessage();
return 0;
}
int vWidth = parser.get<int>("width");
int vHeight = parser.get<int>("height");
float confThreshold = parser.get<float>("thr");
std::string modelTxt = parser.get<String>("modelTxt");
std::string modelBin = parser.get<String>("modelBin");
Net net;
try
{
net = readNet(modelTxt, modelBin);
}
catch (cv::Exception &ee)
{
std::cerr << "Exception: " << ee.what() << std::endl;
std::cout << "Can't load the network by using the flowing files:" << std::endl;
std::cout << "modelTxt: " << modelTxt << std::endl;
std::cout << "modelBin: " << modelBin << std::endl;
return 1;
}
const std::string resultWinName = "Please write the number on white paper and occupy the entire camera.";
const std::string preWinName = "Preprocessing";
namedWindow(preWinName, WINDOW_AUTOSIZE);
namedWindow(resultWinName, WINDOW_AUTOSIZE);
Mat labels, stats, centroids;
Point position;
Rect getRectangle;
bool ifDrawingBox = false;
int classId = 0;
double probability = 0;
Rect basicRect = Rect(0, 0, vWidth, vHeight);
Mat rawImage;
double fps = 0;
// Open a video file or an image file or a camera stream.
VideoCapture cap;
if (parser.has("input"))
cap.open(parser.get<String>("input"));
else
cap.open(parser.get<int>("device"));
TickMeter tm;
while (waitKey(1) < 0)
{
cap >> rawImage;
if (rawImage.empty())
{
waitKey();
break;
}
tm.reset();
tm.start();
Mat image = rawImage.clone();
// Image preprocessing
cvtColor(image, image, COLOR_BGR2GRAY);
GaussianBlur(image, image, Size(3, 3), 2, 2);
adaptiveThreshold(image, image, 255, ADAPTIVE_THRESH_MEAN_C, THRESH_BINARY, 25, 10);
bitwise_not(image, image);
Mat element = getStructuringElement(MORPH_RECT, Size(3, 3), Point(-1,-1));
dilate(image, image, element, Point(-1,-1), 1);
// Find connected component
int nccomps = cv::connectedComponentsWithStats(image, labels, stats, centroids);
for (int i = 1; i < nccomps; i++)
{
ifDrawingBox = false;
// Extend the bounding box of connected component for easier recognition
if (stats.at<int>(i - 1, CC_STAT_AREA) > 80 && stats.at<int>(i - 1, CC_STAT_AREA) < 3000)
{
ifDrawingBox = true;
int left = stats.at<int>(i - 1, CC_STAT_HEIGHT) / 4;
getRectangle = Rect(stats.at<int>(i - 1, CC_STAT_LEFT) - left, stats.at<int>(i - 1, CC_STAT_TOP) - left, stats.at<int>(i - 1, CC_STAT_WIDTH) + 2 * left, stats.at<int>(i - 1, CC_STAT_HEIGHT) + 2 * left);
getRectangle &= basicRect;
}
if (ifDrawingBox && !getRectangle.empty())
{
Mat roi = image(getRectangle);
predictor(net, roi, classId, probability);
if (probability < confThreshold)
continue;
rectangle(rawImage, getRectangle, Scalar(128, 255, 128), 2);
position = Point(getRectangle.br().x - 7, getRectangle.br().y + 25);
putText(rawImage, std::to_string(classId), position, 3, 1.0, Scalar(128, 128, 255), 2);
}
}
tm.stop();
fps = 1 / tm.getTimeSec();
std::string fpsString = format("Inference FPS: %.2f.", fps);
putText(rawImage, fpsString, Point(5, 20), FONT_HERSHEY_SIMPLEX, 0.6, Scalar(128, 255, 128));
imshow(resultWinName, rawImage);
imshow(preWinName, image);
}
return 0;
}
static void getMaxClass(const Mat &probBlob, int &classId, double &classProb)
{
Mat probMat = probBlob.reshape(1, 1);
Point classNumber;
minMaxLoc(probMat, NULL, &classProb, NULL, &classNumber);
classId = classNumber.x;
}
void predictor(Net net, const Mat &roi, int &classId, double &probability)
{
Mat pred;
// Convert Mat to batch of images
Mat inputBlob = dnn::blobFromImage(roi, 1.0, Size(28, 28));
// Set the network input
net.setInput(inputBlob);
// Compute output
pred = net.forward();
getMaxClass(pred, classId, probability);
}