Blame view

3rdparty/opencv-4.5.4/modules/ml/test/test_save_load.cpp 3.55 KB
f4334277   Hu Chunming   提交3rdparty
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
  // This file is part of OpenCV project.
  // It is subject to the license terms in the LICENSE file found in the top-level directory
  // of this distribution and at http://opencv.org/license.html.
  
  #include "test_precomp.hpp"
  
  namespace opencv_test { namespace {
  
  
  void randomFillCategories(const string & filename, Mat & input)
  {
      Mat catMap;
      Mat catCount;
      std::vector<uchar> varTypes;
  
      FileStorage fs(filename, FileStorage::READ);
      FileNode root = fs.getFirstTopLevelNode();
      root["cat_map"] >> catMap;
      root["cat_count"] >> catCount;
      root["var_type"] >> varTypes;
  
      int offset = 0;
      int countOffset = 0;
      uint var = 0, varCount = (uint)varTypes.size();
      for (; var < varCount; ++var)
      {
          if (varTypes[var] == ml::VAR_CATEGORICAL)
          {
              int size = catCount.at<int>(0, countOffset);
              for (int row = 0; row < input.rows; ++row)
              {
                  int randomChosenIndex = offset + ((uint)cv::theRNG()) % size;
                  int value = catMap.at<int>(0, randomChosenIndex);
                  input.at<float>(row, var) = (float)value;
              }
              offset += size;
              ++countOffset;
          }
      }
  }
  
  //==================================================================================================
  
  // Test parameter: (model name, data set name). Both parts are used to build the
  // name of the model file that the test loads.
  using ML_Legacy_Param = tuple<string, string>;
  // Value-parameterized fixture type for the legacy loading test.
  using ML_Legacy_Params = testing::TestWithParam<ML_Legacy_Param>;
  
  // Loads a model file from the "legacy/" test data directory and checks that one
  // prediction on random input finishes without throwing.
  TEST_P(ML_Legacy_Params, legacy_load)
  {
      const string algoName = get<0>(GetParam());
      const string datasetName = get<1>(GetParam());
      const string modelPath = findDataFile("legacy/" + algoName + "_" + datasetName + ".xml");

      // Choose the concrete loader from the model-name part of the parameter.
      Ptr<StatModel> loaded;
      if (algoName == CV_BOOST)
      {
          loaded = Algorithm::load<Boost>(modelPath);
      }
      else if (algoName == CV_ANN)
      {
          loaded = Algorithm::load<ANN_MLP>(modelPath);
      }
      else if (algoName == CV_DTREE)
      {
          loaded = Algorithm::load<DTrees>(modelPath);
      }
      else if (algoName == CV_NBAYES)
      {
          loaded = Algorithm::load<NormalBayesClassifier>(modelPath);
      }
      else if (algoName == CV_SVM)
      {
          loaded = Algorithm::load<SVM>(modelPath);
      }
      else if (algoName == CV_RTREES)
      {
          loaded = Algorithm::load<RTrees>(modelPath);
      }
      else if (algoName == CV_SVMSGD)
      {
          loaded = Algorithm::load<SVMSGD>(modelPath);
      }
      ASSERT_TRUE(loaded);

      // Boost, DTrees and RTrees get ten samples, categorical columns filled with
      // labels taken from the model file, and the PREDICT_SUM flag.
      const bool treeBased = algoName == CV_BOOST || algoName == CV_DTREE || algoName == CV_RTREES;
      const int sampleRows = treeBased ? 10 : 1;

      Mat samples(sampleRows, loaded->getVarCount(), CV_32F);
      cv::theRNG().fill(samples, RNG::UNIFORM, 0, 40);
      if (treeBased)
      {
          randomFillCategories(modelPath, samples);
      }

      int predictFlags = StatModel::RAW_OUTPUT;
      if (treeBased)
      {
          predictFlags |= DTrees::PREDICT_SUM;
      }

      // Only the absence of internal assertions/errors is checked; the predicted
      // values themselves are not inspected.
      Mat responses;
      EXPECT_NO_THROW(loaded->predict(samples, responses, predictFlags));
  }
  
  // (model name, data set name) pairs covered by the legacy loading test. Each entry
  // selects the file "legacy/<model>_<data>.xml" from the test data.
  // The table is only ever read, so it is const. Keep the order stable: the generated
  // test instances are numbered by position in this array.
  const ML_Legacy_Param param_list[] = {
      ML_Legacy_Param(CV_ANN, "waveform"),
      ML_Legacy_Param(CV_BOOST, "adult"),
      ML_Legacy_Param(CV_BOOST, "1"),
      ML_Legacy_Param(CV_BOOST, "2"),
      ML_Legacy_Param(CV_BOOST, "3"),
      ML_Legacy_Param(CV_DTREE, "abalone"),
      ML_Legacy_Param(CV_DTREE, "mushroom"),
      ML_Legacy_Param(CV_NBAYES, "waveform"),
      ML_Legacy_Param(CV_SVM, "poletelecomm"),
      ML_Legacy_Param(CV_SVM, "waveform"),
      ML_Legacy_Param(CV_RTREES, "waveform"),
      ML_Legacy_Param(CV_SVMSGD, "waveform"),
  };

  // Instantiate the parameterized test once per table entry (empty instantiation prefix).
  INSTANTIATE_TEST_CASE_P(/**/, ML_Legacy_Params, testing::ValuesIn(param_list));
  
  /*TEST(ML_SVM, throw_exception_when_save_untrained_model)
  {
      Ptr<cv::ml::SVM> svm;
      string filename = tempfile("svm.xml");
      ASSERT_THROW(svm.save(filename.c_str()), Exception);
      remove(filename.c_str());
  }*/
  
  }} // namespace