How do I load and use TrainData?

asked 2016-12-29 03:03:56 -0600 by shreya

updated 2016-12-29 03:05:16 -0600 by berak

I downloaded a training set for digit recognition. The training set is in CSV format: the first column contains the response (the digit), and the pixel values for that response are stored in the rest of the corresponding row. I am trying to write C++ code in Visual Studio that uses the ml module of OpenCV 3.0 to train from this data. I can load the data using the TrainData::loadFromCSV function, but I do not understand how it generates the samples and responses matrices that need to be fed into TrainData::create.

My code:

void main()
{
    TrainData::loadFromCSV("C:/Users/Shreya Srivastava/Desktop/train.csv",0,1,10);
    Mat samples ( 28,28, CV_32FC1 );
    Mat responses(35591,0,CV_32F);
    TrainData::create(samples,ROW_SAMPLE,responses);
    imshow("samples",samples);

    waitKey();
}

Comments

Can you add a single line from your CSV? Do you really have 10 responses per feature (i.e. one-hot encoded)?

What kind of machine learning do you plan to use?

berak ( 2016-12-29 03:06:00 -0600 )

Yes, I can use the ifstream interface of C++, but will it help?

shreya ( 2016-12-29 03:11:49 -0600 )

The training set has one response per row: the digit that corresponds to the image whose pixel values are stored in the following columns of that row.

shreya ( 2016-12-29 03:25:50 -0600 )

Is that digit at the end or at the start of the row?

berak ( 2016-12-29 03:33:15 -0600 )

I am new to machine learning and was following the Coursera open course by Andrew Ng. I am trying to build a digit recognition project using neural networks.

shreya ( 2016-12-29 03:34:31 -0600 )

I also have a test set, also in CSV format, in which each row contains the pixel values.

shreya ( 2016-12-29 03:35:56 -0600 )

The digit is at the start of the row.

shreya ( 2016-12-29 03:38:50 -0600 )

Again, what kind of machine learning? (I'll have to revise the answer below.)

berak ( 2016-12-29 03:46:27 -0600 )

What exactly do you mean by "what kind of machine learning"? I am using an ANN.

shreya ( 2016-12-29 03:52:47 -0600 )

1 answer


answered 2016-12-29 03:31:52 -0600 by berak

updated 2016-12-29 04:04:25 -0600

Note that TrainData::loadFromCSV returns an instance, which you need to keep:

Ptr<TrainData> tdata = TrainData::loadFromCSV("C:/Users/Shreya Srivastava/Desktop/train.csv",
       0, // lines to skip
       0, // 1st elem is the label
      -1); // only 1 response per line


// for NBayes, or LogisticRegression, using float labels, you could use it "as is":
ml->train(tdata);

// for an SVM or KNearest, you'd need to convert the responses to integer:
Mat trainData = tdata->getTrainSamples();
Mat trainLabels = tdata->getTrainResponses();
trainLabels.convertTo(trainLabels, CV_32S); // needed for SVM, KNearest
svm->train(trainData, ROW_SAMPLE, trainLabels); // ROW_SAMPLE == 0: one sample per row

//
// for an ANN, you need to "one-hot encode" the labels:
// for 5 classes, and a label of 2, it looks like:
// [0,0,1,0,0]
// (this is the expected state of the output neurons)
// (use this instead of the SVM / KNearest block above, not in addition to it)
//
Mat trainData = tdata->getTrainSamples();
Mat trainLabels = tdata->getTrainResponses();
int numClasses = 10; // assuming mnist
Mat hot = Mat::zeros(trainLabels.rows, numClasses, CV_32F); // all zero, initially
for (int i=0; i<trainLabels.rows; i++) {
        int id = (int)trainLabels.at<float>(i);
        hot.at<float>(i, id) = 1.0f;
}
ann->train(trainData, ROW_SAMPLE, hot);
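
To see the one-hot step in isolation, here is a minimal, library-free sketch in plain C++ (no OpenCV needed to try it). The helper name oneHotEncode is hypothetical, not an OpenCV function, and nested std::vector is only a stand-in for cv::Mat; it also assumes labels are already integers in the range 0..numClasses-1.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of OpenCV): turn integer class labels into
// one-hot rows, e.g. label 2 with 5 classes becomes {0,0,1,0,0}.
std::vector<std::vector<float>> oneHotEncode(const std::vector<int>& labels,
                                             std::size_t numClasses)
{
    // One row per label, one column per class, all zero initially.
    std::vector<std::vector<float>> hot(labels.size(),
                                        std::vector<float>(numClasses, 0.0f));
    for (std::size_t i = 0; i < labels.size(); ++i) {
        const int id = labels[i];
        // Labels outside [0, numClasses) are left as an all-zero row
        // instead of writing out of bounds.
        if (id >= 0 && static_cast<std::size_t>(id) < numClasses)
            hot[i][static_cast<std::size_t>(id)] = 1.0f;
    }
    return hot;
}
```

For example, oneHotEncode({2, 0}, 5) gives the rows {0,0,1,0,0} and {1,0,0,0,0}. The bounds check is a design choice of this sketch; the OpenCV loop above has no such check, so a bad label in the CSV would write outside the matrix there.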

For your test set, just repeat the same steps (this assumes test.csv also has the label in the first column):

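Ptr<TrainData> tdata = TrainData::loadFromCSV("C:/Users/Shreya Srivastava/Desktop/test.csv", 0,0,-1);
Mat testData = tdata->getTrainSamples();
Mat testLabels = tdata->getTrainResponses();
Mat testResults;
ann->predict(testData, testResults); // one row of numClasses outputs per sample
int correct = 0;
for (int i=0; i<testResults.rows; i++) {
        Point best; // best.x is the index of the strongest output neuron
        minMaxLoc(testResults.row(i), 0, 0, 0, &best);
        correct += (best.x == (int)testLabels.at<float>(i));
}
float accuracy = float(correct) / testLabels.rows;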
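
One thing to keep in mind when scoring an ANN: its prediction is one row per sample with one column per class, so the predicted digit is the index of the largest value in that row, and that index is what gets compared to the label. Here is a rough, library-free sketch of that idea in plain C++; argmaxRow and accuracy are hypothetical helper names, and nested std::vector again stands in for cv::Mat.

```cpp
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <vector>

// Hypothetical helper: index of the largest value in one row of network
// outputs (returns 0 for an empty row).
int argmaxRow(const std::vector<float>& row)
{
    return static_cast<int>(
        std::distance(row.begin(), std::max_element(row.begin(), row.end())));
}

// Hypothetical helper: fraction of samples whose strongest output neuron
// matches the expected label.
float accuracy(const std::vector<std::vector<float>>& outputs,
               const std::vector<int>& labels)
{
    if (outputs.empty() || outputs.size() != labels.size())
        return 0.0f; // nothing to compare, or mismatched sizes
    int correct = 0;
    for (std::size_t i = 0; i < outputs.size(); ++i)
        if (argmaxRow(outputs[i]) == labels[i])
            ++correct;
    return static_cast<float>(correct) / static_cast<float>(outputs.size());
}
```

With real data you would fill outputs from the prediction matrix and labels from the first CSV column; returning 0 on a size mismatch is just a choice made for this sketch.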

Comments

Thanks a lot for the help.

shreya ( 2016-12-29 03:39:39 -0600 )

Made some updates, so it's more specific to your situation!

berak ( 2016-12-29 04:05:07 -0600 )

http://answers.opencv.org/question/12... Can you please help resolve this error?

shreya ( 2016-12-30 09:45:27 -0600 )
