
Fail using MobileSSD caffe model

asked 2018-05-14 02:02:01 -0600

Shay Weissman

updated 2018-05-14 02:08:44 -0600

berak

I downloaded MobileNetSSD_deploy.caffemodel from https://github.com/chuanqi305/MobileN... and renamed deploy.prototxt to MobileNetSSD_deploy.txt. I started from the routine supplied in the OpenCV dnn example, which worked for me with the GoogLeNet Caffe model. I changed the code to load the new model, and loading succeeded. I then changed these two lines to:

Mat inputBlob = blobFromImage(img, 1, Size(300, 300), Scalar(104, 117, 123));   //Convert Mat to batch of images

cv::Mat prob = net.forward("detection_out");

I get an assertion failure at getMaxClass(prob, &classId, &classProb). It seems as if prob is not created by the forward call. I am using OpenCV 3.4.1 on Windows 7 with VS2015.

The full code:

cv::String modelTxt = "MobileNetSSD_deploy.txt";
cv::String modelBin = "MobileNetSSD_deploy.caffemodel";
Net net = dnn::readNetFromCaffe(modelTxt, modelBin);
if (net.empty())
{
    std::cerr << "Can't load network by using the following files: " << std::endl;
    std::cerr << "prototxt:   " << modelTxt << std::endl;
    std::cerr << "caffemodel: " << modelBin << std::endl;
    std::cerr << "bvlc_googlenet.caffemodel can be downloaded here:" << std::endl;
    std::cerr << "http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel" << std::endl;
    exit(-1);
}
std::cout << "Network loaded successfuly\n";
cv::String imageFile = "Goose.bmp";
//      cv::String imageFile = "space_shuttle.jpg";
Mat img = imread(imageFile, IMREAD_UNCHANGED);
int nChan = img.channels();
cout << nChan;

if (img.empty())
{
    std::cerr << "Can't read image from the file: " << imageFile << std::endl;
    exit(-1);
}
//GoogLeNet accepts only 224x224 RGB-images
Mat inputBlob = blobFromImage(img, 1, Size(300, 300), Scalar(104, 117, 123));   //Convert Mat to batch of images

net.setInput(inputBlob, "data");        //set the network input
cv::Mat prob = net.forward("detection_out");                          //compute output
int classId;
double classProb;
getMaxClass(prob, &classId, &classProb);//find the best class
std::vector<String> classNames = readClassNames();
std::cout << "Best class: #" << classId << " '" << classNames.at(classId) << "'" << std::endl;
std::cout << "Probability: " << classProb * 100 << "%" << std::endl;

Comments

Btw, the scalefactor argument of blobFromImage should be 0.007843f, not 1.

berak (2018-05-14 02:30:28 -0600)
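To illustrate why the scale factor matters, here is a minimal stand-alone sketch of the per-pixel arithmetic that blobFromImage applies: subtract the mean, then multiply by the scale factor. The mean of 127.5 is an assumption taken from the OpenCV MobileNet-SSD sample, not from the comment above, and `normalizePixel` is a hypothetical helper, not an OpenCV function.

```cpp
#include <cassert>

// Hypothetical helper (not part of OpenCV): mirrors what blobFromImage does
// to each pixel value, i.e. subtract the mean, then multiply by the scale.
// Assumed defaults: mean = 127.5 and scale = 0.007843 (roughly 1 / 127.5).
inline float normalizePixel(float value, float mean = 127.5f, float scale = 0.007843f)
{
    assert(scale > 0.0f);  // this sketch only covers positive scale factors
    return (value - mean) * scale;
}
```

With these assumed values, the 0..255 pixel range maps to roughly -1..1. With a scale of 1 and a mean of 104, as in the question, a pixel value of 255 becomes 151, which is most likely far outside the range the network expects.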

1 answer


answered 2018-05-14 02:22:43 -0600

berak

updated 2018-05-14 02:24:59 -0600

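MobileNet-SSD is an object detection network, not a classification one, so the classification post-processing from the GoogLeNet example (getMaxClass) does not apply to its output.

It also has only 20 object classes (plus background), not 1000 like GoogLeNet.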

you will need code similar to this, to extract bounding boxes / probabilities / labels:

(please also see the sample here)

const char* classNames[] = {"background",
                        "aeroplane", "bicycle", "bird", "boat",
                        "bottle", "bus", "car", "cat", "chair",
                        "cow", "diningtable", "dog", "horse",
                        "motorbike", "person", "pottedplant",
                        "sheep", "sofa", "train", "tvmonitor"};



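// "frame" is the image that was passed to blobFromImage ("img" in the question's code).
// The "detection_out" blob has shape 1 x 1 x N x 7; view it as an N x 7 float matrix.
Mat detection = net.forward("detection_out");
Mat detectionMat(detection.size[2], detection.size[3], CV_32F, detection.ptr<float>());

const size_t numClasses = sizeof(classNames) / sizeof(classNames[0]);
float confidenceThreshold = 0.5f;
for (int i = 0; i < detectionMat.rows; i++)
{
    // Each row: [imageId, classId, confidence, left, top, right, bottom]
    float confidence = detectionMat.at<float>(i, 2);

    if (confidence > confidenceThreshold)
    {
        size_t objectClass = (size_t)(detectionMat.at<float>(i, 1));
        if (objectClass >= numClasses)
            continue;   // skip class ids that have no label

        // Box corners are normalized to [0, 1]; scale them to pixels.
        // Despite the names, (xLeftBottom, yLeftBottom) is used as the top-left
        // corner and (xRightTop, yRightTop) as the bottom-right corner.
        float xLeftBottom = detectionMat.at<float>(i, 3) * frame.cols;
        float yLeftBottom = detectionMat.at<float>(i, 4) * frame.rows;
        float xRightTop = detectionMat.at<float>(i, 5) * frame.cols;
        float yRightTop = detectionMat.at<float>(i, 6) * frame.rows;

        ostringstream ss;
        ss << confidence;
        String conf(ss.str());

        Rect object((int)xLeftBottom, (int)yLeftBottom,
                    (int)(xRightTop - xLeftBottom),
                    (int)(yRightTop - yLeftBottom));

        // Draw the bounding box, then the label on a white background.
        rectangle(frame, object, Scalar(0, 255, 0));
        String label = String(classNames[objectClass]) + ": " + conf;
        int baseLine = 0;
        Size labelSize = getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
        rectangle(frame, Rect(Point((int)xLeftBottom, (int)yLeftBottom - labelSize.height),
                              Size(labelSize.width, labelSize.height + baseLine)),
                  Scalar(255, 255, 255), CV_FILLED);
        putText(frame, label, Point((int)xLeftBottom, (int)yLeftBottom),
                FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 0, 0));
    }
}

imshow("detections", frame);
waitKey();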
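The row layout that the loop above relies on can also be sketched without OpenCV, which may help when debugging the output by hand. This is only a sketch under the assumption that each detection is 7 floats in the order [imageId, classId, confidence, left, top, right, bottom], with box corners normalized to [0, 1]; `Detection` and `decodeDetections` are hypothetical names, not OpenCV API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One decoded detection, in pixel coordinates of the original image.
struct Detection {
    int classId;
    float confidence;
    int left, top, right, bottom;
};

// Hypothetical helper: decodes a flat N x 7 buffer whose rows are assumed to be
// [imageId, classId, confidence, left, top, right, bottom].
// Rows at or below the threshold, or with a class id outside
// [0, numClasses), are skipped.
std::vector<Detection> decodeDetections(const std::vector<float>& rows,
                                        int imageWidth, int imageHeight,
                                        float threshold, std::size_t numClasses)
{
    assert(rows.size() % 7 == 0);  // the buffer must hold whole rows
    std::vector<Detection> out;
    for (std::size_t i = 0; i + 6 < rows.size(); i += 7) {
        const float rawId = rows[i + 1];
        const float confidence = rows[i + 2];
        if (confidence <= threshold)
            continue;
        if (rawId < 0.0f || static_cast<std::size_t>(rawId) >= numClasses)
            continue;

        Detection d;
        d.classId = static_cast<int>(rawId);
        d.confidence = confidence;
        // Scale normalized corners to pixel coordinates.
        d.left = static_cast<int>(rows[i + 3] * imageWidth);
        d.top = static_cast<int>(rows[i + 4] * imageHeight);
        d.right = static_cast<int>(rows[i + 5] * imageWidth);
        d.bottom = static_cast<int>(rows[i + 6] * imageHeight);
        out.push_back(d);
    }
    return out;
}
```

With the 21-entry classNames array from the answer, numClasses would be 21, so a class id of 15 maps to "person" and anything at or above 21 is dropped rather than read out of bounds.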
