Blame view

3rdparty/opencv-4.5.4/modules/ml/test/test_lr.cpp 2.83 KB
f4334277   Hu Chunming   提交3rdparty
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
  // This file is part of OpenCV project.
  // It is subject to the license terms in the LICENSE file found in the top-level directory
  // of this distribution and at http://opencv.org/license.html.
  //
  // AUTHOR: Rahul Kavi rahulkavi[at]live[dot]com
  
  //
  // Test data uses subset of data from the popular Iris Dataset (1936):
  // - http://archive.ics.uci.edu/ml/datasets/Iris
  // - https://en.wikipedia.org/wiki/Iris_flower_data_set
  //
  
  #include "test_precomp.hpp"
  
  namespace opencv_test { namespace {
  
  // Trains a logistic-regression model on the Iris subset (full-batch gradient
  // descent, L2 regularization) and predicts on the very samples it was trained
  // on; the error reported by calculateError() must not exceed 0.05.
  TEST(ML_LR, accuracy)
  {
      const std::string irisPath = findDataFile("iris.data");
      // Second argument is the header-line count: 0 means no lines are skipped.
      Ptr<TrainData> irisData = TrainData::loadFromCSV(irisPath, 0);
      ASSERT_FALSE(irisData.empty());

      Ptr<LogisticRegression> lr = LogisticRegression::create();
      lr->setLearningRate(1.0);
      lr->setIterations(10001);
      lr->setRegularization(LogisticRegression::REG_L2);
      lr->setTrainMethod(LogisticRegression::BATCH);
      // Per the OpenCV docs this is only consulted by the MINI_BATCH method.
      lr->setMiniBatchSize(10);
      lr->train(irisData);

      // Evaluate on the training samples themselves.
      Mat predicted;
      lr->predict(irisData->getSamples(), predicted);

      // Start far above the threshold so the bound below cannot pass unless
      // calculateError() overwrites this value.
      float trainError = 1000;
      EXPECT_TRUE(calculateError(predicted, irisData->getResponses(), trainError));
      EXPECT_LE(trainError, 0.05f);
  }
  
  //==================================================================================================
  
  // Persistence round-trip for LogisticRegression: train a model, save it to a
  // temporary XML file, load it back through Algorithm::load(), and check that
  // the restored model yields identical predictions and identical learnt thetas.
  TEST(ML_LR, save_load)
  {
      string dataFileName = findDataFile("iris.data");
      Ptr<TrainData> tdata = TrainData::loadFromCSV(dataFileName, 0);
      ASSERT_FALSE(tdata.empty());
      Mat responses1, responses2;
      Mat learnt_mat1, learnt_mat2;
      String filename = tempfile(".xml");

      // Remove the temporary model file on every exit path. The ASSERT_* macros
      // below return from the test body early on failure; without this guard
      // such a return would skip the cleanup and leave the file behind.
      struct TempFileRemover
      {
          String path;
          ~TempFileRemover() { remove(path.c_str()); }
      } tempFileRemover = { filename };

      {
          Ptr<LogisticRegression> lr1 = LogisticRegression::create();
          lr1->setLearningRate(1.0);
          lr1->setIterations(10001);
          lr1->setRegularization(LogisticRegression::REG_L2);
          lr1->setTrainMethod(LogisticRegression::BATCH);
          lr1->setMiniBatchSize(10);
          ASSERT_NO_THROW(lr1->train(tdata));
          ASSERT_NO_THROW(lr1->predict(tdata->getSamples(), responses1));
          ASSERT_NO_THROW(lr1->save(filename));
          learnt_mat1 = lr1->get_learnt_thetas();
      }
      {
          Ptr<LogisticRegression> lr2;
          ASSERT_NO_THROW(lr2 = Algorithm::load<LogisticRegression>(filename));
          // Algorithm::load() may hand back an empty Ptr instead of throwing;
          // fail cleanly here rather than dereferencing null on the next line.
          ASSERT_FALSE(lr2.empty());
          ASSERT_NO_THROW(lr2->predict(tdata->getSamples(), responses2));
          learnt_mat2 = lr2->get_learnt_thetas();
      }

      // The restored model must reproduce the original predictions exactly.
      EXPECT_MAT_NEAR(responses1, responses2, 0.f);

      // The stored parameters must also match exactly (tolerance 0). Using the
      // same macro as for the predictions gives a regular test failure on a
      // shape/type mismatch instead of an exception from an element-wise compare.
      EXPECT_MAT_NEAR(learnt_mat1, learnt_mat2, 0.f);
  }
  
  }} // namespace