Ask Your Question

Revision history [back]

Good to see that ANN_MLP has a ReLU activation function, but why does it return NaN?

The following program uses ANN_MLP to learn the XOR function. It works fine with the sigmoid and Gaussian activation functions, but predicts NaN for ReLU and leaky ReLU. Any ideas? Using OpenCV 4.0.1.

    #include <opencv2/core.hpp>
#include <opencv2/ml/ml.hpp>

#include <iostream>
#include <iomanip>

using namespace cv;
using namespace ml;
using namespace std;

// Prints, for every entry of the network's layer-size matrix, the layer
// index, its size, and the matrix returned by getWeights() for that index.
//
// mlp: the ANN_MLP instance to inspect; only queried here.
void PrintWts(Ptr<ANN_MLP>& mlp) {
    const Mat layerSizes = mlp->getLayerSizes();
    // Number of elements in the layer-size matrix (rows * cols).
    const int layerCount = layerSizes.size().area();

    for (int idx = 0; idx < layerCount; ++idx) {
        const int layerSize = layerSizes.at<int>(idx);
        cout << format("layer %d: size(%d)\n", idx, layerSize);
        cout << "weights:\n";

        const Mat weights = mlp->getWeights(idx);
        cout << weights << "\n\n";
    }
}

void main() {
    vector<int> layerSizes = { 2, 4, 1 };
    vector<float> inputTrainingData = {
        0.0, 0.0,
        0.0, 1.0,
        1.0, 0.0,
        1.0, 1.0
    };
    Mat inputTrainingMat(Size(2, 4), CV_32FC1, inputTrainingData.data());
    vector<float> outputTrainingData = {
        0.0,
        1.0,
        1.0,
        0.0,
    };
    Mat outputTrainingMat(Size(1, 4), CV_32FC1, outputTrainingData.data());

    Ptr<TrainData> trainingData = TrainData::create(
        inputTrainingMat,
        ROW_SAMPLE,
        outputTrainingMat
    );
    TermCriteria termCrit = TermCriteria(
        TermCriteria::Type::COUNT + TermCriteria::Type::EPS,
        220,
        0.00000001
    );

    Ptr<ANN_MLP> mlp = ANN_MLP::create();
    mlp->setLayerSizes(layerSizes);
    mlp->setActivationFunction(ANN_MLP::SIGMOID_SYM);
    mlp->setTermCriteria(termCrit);
    mlp->setTrainMethod(ANN_MLP::BACKPROP, 0.1, 0.1);

    mlp->train(trainingData);

    PrintWts(mlp);

    if (mlp->isTrained()) {
        for (int i = 0; i < inputTrainingMat.rows; i++) {
            Mat sample = inputTrainingMat.row(i);
            Mat result;
            mlp->predict(sample, result);
            cout << sample << " -> " << result << endl;
        }
    }
    system("pause");
}