Random forest - RTrees - getSplits - type mismatch

asked 2016-03-03 12:06:31 -0600

mask2007

updated 2016-03-09 13:56:09 -0600

Hello,

I'm trying to access the trees and nodes of an `RTrees` model with the following code:

```
#include <iostream>

#include <cstdio>   // getchar
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/ml/ml.hpp>

using namespace std;
using namespace cv;
using namespace cv::ml;

static Ptr<TrainData> prepareTrainData(Mat& pts, Mat& classes)
{
    Mat samples;
    Mat(pts).reshape(1, (int)pts.rows).convertTo(samples, CV_32F);
    Ptr<TrainData> a = TrainData::create(samples, ROW_SAMPLE, classes);
    return a;
}

int main()
{
    Mat training = (Mat_<float>(7, 2) << 1, 1, 2, 2, 3, 3, 1, 2, 3, 4, 5, 4, 3, 2);
    Mat classes = (Mat_<int>(7, 1) << 0, 0, 0, 1, 1, 2, 2);

    cout << "training data" << endl;
    cout << endl << training << " " << endl << endl;
    cout << "classes" << endl;
    cout << endl << classes << " " << endl << endl;

    Ptr<RTrees> rtrees = RTrees::create();

    rtrees->setMaxDepth(14);
    rtrees->setMinSampleCount(2);
    rtrees->setRegressionAccuracy(0.f);
    rtrees->setUseSurrogates(false);
    rtrees->setMaxCategories(3);
    rtrees->setPriors(Mat());
    rtrees->setCalculateVarImportance(false);

    rtrees->setActiveVarCount(1);
    rtrees->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 5, 0));
    Ptr<TrainData> a = prepareTrainData(training, classes);
    rtrees->train(a);

    // my own prediction
    vector<int> roots = rtrees->getRoots();             // runtime error
    vector<DTrees::Node> nodes = rtrees->getNodes();    // runtime error
    vector<DTrees::Split> splits = rtrees->getSplits();  // runtime error

    Range range = Range(0, (int)roots.size());
    Mat sample = training.row(0);

    int nclasses = 3;   // the labels in 'classes' are 0, 1 and 2
    vector<int> votes(nclasses, 0);
    Mat varIdx0 = a->getVarIdx();
    vector<int> varIdx;
    varIdx0.copyTo(varIdx);                                                        // wrong copy

    getchar();
    return 0;
}

```
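Since the code collects `roots`, `nodes`, `splits`, `range`, `sample` and `votes`, the goal seems to be reproducing the forest's majority vote by hand. Below is a minimal sketch of that traversal, not OpenCV's own implementation. It is templated so it could take the vectors returned by `getRoots()`, `getNodes()` and `getSplits()` directly, and it only reads fields that `cv::ml::DTrees::Node` (`left`, `right`, `split`, `classIdx`) and `cv::ml::DTrees::Split` (`varIdx`, `inversed`, `c`) declare. Assumptions: every variable is ordered (numeric, no categorical subsets), no values are missing, all columns are active so `varIdx` indexes the sample directly, and `inversed` swaps the two branches. `findLeaf`, `predictByVoting`, `NodeLike` and `SplitLike` are hypothetical names; the two structs are stand-ins so the sketch compiles without OpenCV.

```cpp
#include <vector>

// Hypothetical stand-ins with the same field names as
// cv::ml::DTrees::Node and cv::ml::DTrees::Split (only the fields used here).
struct NodeLike  { double value; int classIdx; int left; int right; int split; };
struct SplitLike { int varIdx; bool inversed; float c; };

// Walk one tree from `root` down to a leaf and return the leaf's node index.
// Assumes ordered (numeric) variables only and no missing values.
template <class Node, class Split>
int findLeaf(int root, const std::vector<Node>& nodes,
             const std::vector<Split>& splits, const std::vector<float>& sample)
{
    int nidx = root;
    while (nodes[nidx].split >= 0)              // a negative split index marks a leaf
    {
        const Node& node = nodes[nidx];
        const Split& s = splits[node.split];
        bool goLeft = sample[s.varIdx] <= s.c;  // ordered split: left when value <= threshold
        if (s.inversed)                         // assumption: inversed swaps the branches
            goLeft = !goLeft;
        nidx = goLeft ? node.left : node.right;
    }
    return nidx;
}

// Let every tree vote with its leaf's classIdx and return the class with the
// most votes (ties go to the lowest class index).
template <class Node, class Split>
int predictByVoting(const std::vector<int>& roots, const std::vector<Node>& nodes,
                    const std::vector<Split>& splits, const std::vector<float>& sample,
                    int nclasses)
{
    std::vector<int> votes(nclasses, 0);
    for (int root : roots)
        votes[nodes[findLeaf(root, nodes, splits, sample)].classIdx]++;

    int best = 0;
    for (int k = 1; k < nclasses; ++k)
        if (votes[k] > votes[best])
            best = k;
    return best;
}
```

With OpenCV, a call might look like `predictByVoting(rtrees->getRoots(), rtrees->getNodes(), rtrees->getSplits(), sampleVec, 3)`, where `sampleVec` is a hypothetical `std::vector<float>` holding the two features of `training.row(0)`; this passes the returned references through without copying the vectors. Comparing the result with `rtrees->predict(sample)` is a simple way to check whether the assumptions above hold for a trained model.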

I keep getting a memory access violation on the three lines marked `// runtime error`. Note that `rtrees->getSplits()` works when called by itself, so I guess the problem is with the data type of `splits`.

Also, when I read information from the training data and try to copy it into a variable:

```
std::vector<int> catMap;
data->getCatMap().copyTo(catMap);
```

it copies garbage into `catMap`, and I have to copy it element by element instead of using the `copyTo` method. Do you have any idea about these two problems?

Thanks,


Comments

Which OS / OpenCV version is it? Did you build it locally?

(I can't reproduce any of your problems with OpenCV 3.1 on Windows.)

How do you set up your TrainData?

berak (2016-03-03 14:58:23 -0600)

@berak it's Windows 10 and OpenCV 3.1. I'll put the whole code up there.

mask2007 (2016-03-09 13:53:01 -0600)