Blame view

3rdparty/opencv-4.5.4/samples/cpp/tree_engine.cpp 3.87 KB
f4334277   Hu Chunming   提交3rdparty
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
  #include "opencv2/ml.hpp"
  #include "opencv2/core.hpp"
  #include "opencv2/core/utility.hpp"
  #include <stdio.h>
  #include <iostream>
  #include <map>
  #include <string>
  
  using namespace cv;
  using namespace cv::ml;
  
  // Prints a short description of the sample and its command-line usage to
  // stdout. argv[0] is substituted into the usage line as the program name.
  static void help(char** argv)
  {
      // Single printf-style format; the only conversion is the %s for argv[0].
      static const char usage_format[] =
          "\nThis sample demonstrates how to use different decision trees and forests including boosting and random trees.\n"
          "Usage:\n\t%s [-r=<response_column>] [-ts=type_spec] <csv filename>\n"
          "where -r=<response_column> specified the 0-based index of the response (0 by default)\n"
          "-ts= specifies the var type spec in the form ord[n1,n2-n3,n4-n5,...]cat[m1-m2,m3,m4-m5,...]\n"
          "<csv filename> is the name of training data file in comma-separated value format\n\n";

      const char* const program_name = argv[0];
      printf(usage_format, program_name);
  }
  
  // Trains `model` on `data` and reports the result on stdout.
  //
  // If train() returns false, prints "Training failed" and nothing else.
  // Otherwise prints two calcError() results: the call made with the second
  // argument false is labelled "train error", the one made with true is
  // labelled "test error". No per-sample responses are requested (noArray()).
  static void train_and_print_errs(Ptr<StatModel> model, const Ptr<TrainData>& data)
  {
      if( !model->train(data) )
      {
          printf("Training failed\n");
          return;
      }

      // Evaluate and print one figure at a time, in the same order as the labels.
      const auto train_err = model->calcError(data, false, noArray());
      printf( "train error: %f\n", train_err );

      const auto test_err = model->calcError(data, true, noArray());
      printf( "test error: %f\n\n", test_err );
  }
  
  int main(int argc, char** argv)
  {
      cv::CommandLineParser parser(argc, argv, "{ help h | | }{r | 0 | }{ts | | }{@input | | }");
      if (parser.has("help"))
      {
          help(argv);
          return 0;
      }
      std::string filename = parser.get<std::string>("@input");
      int response_idx;
      std::string typespec;
      response_idx = parser.get<int>("r");
      typespec = parser.get<std::string>("ts");
      if( filename.empty() || !parser.check() )
      {
          parser.printErrors();
          help(argv);
          return 0;
      }
      printf("\nReading in %s...\n\n",filename.c_str());
      const double train_test_split_ratio = 0.5;
  
      Ptr<TrainData> data = TrainData::loadFromCSV(filename, 0, response_idx, response_idx+1, typespec);
      if( data.empty() )
      {
          printf("ERROR: File %s can not be read\n", filename.c_str());
          return 0;
      }
  
      data->setTrainTestSplitRatio(train_test_split_ratio);
      std::cout << "Test/Train: " << data->getNTestSamples() << "/" << data->getNTrainSamples();
  
      printf("======DTREE=====\n");
      Ptr<DTrees> dtree = DTrees::create();
      dtree->setMaxDepth(10);
      dtree->setMinSampleCount(2);
      dtree->setRegressionAccuracy(0);
      dtree->setUseSurrogates(false);
      dtree->setMaxCategories(16);
      dtree->setCVFolds(0);
      dtree->setUse1SERule(false);
      dtree->setTruncatePrunedTree(false);
      dtree->setPriors(Mat());
      train_and_print_errs(dtree, data);
  
      if( (int)data->getClassLabels().total() <= 2 ) // regression or 2-class classification problem
      {
          printf("======BOOST=====\n");
          Ptr<Boost> boost = Boost::create();
          boost->setBoostType(Boost::GENTLE);
          boost->setWeakCount(100);
          boost->setWeightTrimRate(0.95);
          boost->setMaxDepth(2);
          boost->setUseSurrogates(false);
          boost->setPriors(Mat());
          train_and_print_errs(boost, data);
      }
  
      printf("======RTREES=====\n");
      Ptr<RTrees> rtrees = RTrees::create();
      rtrees->setMaxDepth(10);
      rtrees->setMinSampleCount(2);
      rtrees->setRegressionAccuracy(0);
      rtrees->setUseSurrogates(false);
      rtrees->setMaxCategories(16);
      rtrees->setPriors(Mat());
      rtrees->setCalculateVarImportance(true);
      rtrees->setActiveVarCount(0);
      rtrees->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 0));
      train_and_print_errs(rtrees, data);
      cv::Mat ref_labels = data->getClassLabels();
      cv::Mat test_data = data->getTestSampleIdx();
      cv::Mat predict_labels;
      rtrees->predict(data->getSamples(), predict_labels);
  
      cv::Mat variable_importance = rtrees->getVarImportance();
      std::cout << "Estimated variable importance" << std::endl;
      for (int i = 0; i < variable_importance.rows; i++) {
          std::cout << "Variable " << i << ": " << variable_importance.at<float>(i, 0) << std::endl;
      }
      return 0;
  }