Ask Your Question

Revision history [back]

click to hide/show revision 1
initial version

Low accuracy of SVM on Android

Hello guys,

I have an Android project that uses the face detection feature (Cascade Classifier). After recognizing the face, I trim the eyes and use a descriptor (Kaze) in conjunction with Bag of Words and SVM to recognize open and closed eyes, left and right. I performed the SVM training with a set of 1600 images, 400 of each type.

The training is done in C#, and after finishing the training, I test with some images, and all are classified correctly.

However, when importing the resulting dictionary and svm files, using the camera results have been very bad.

Here are the images and files generated through training:

https://github.com/caiocanalli/OpenCV-Resources/tree/master

Attached, the C#, Java, and C ++ code I'm using.

Any suggestion of how to improve preaching and welcome.

Thank you.

public class Training
{
    private const string Path = "C:/Users/Administrator/Desktop/Data";

    private readonly KAZE _kaze;
    private readonly BFMatcher _bfMatcher;
    private readonly BOWKMeansTrainer _bowKMeansTrainer;
    private readonly BOWImgDescriptorExtractor _bowImgDescriptorExtractor;
    private readonly Dictionary<int, string> _images;

    private Mat _kazeDescriptors;
    private Mat _bowDescriptors;

    private Matrix<int> label;

    public Training()
    {
        int dictionarySize = 10000;

        _kaze = new KAZE(true, true, 0.00001F);
        _bfMatcher = new BFMatcher(DistanceType.L2);

        _bowKMeansTrainer = new BOWKMeansTrainer(
            dictionarySize,
            new MCvTermCriteria(200, 0.00001),
            1,
            KMeansInitType.PPCenters);

        _bowImgDescriptorExtractor =
            new BOWImgDescriptorExtractor(_kaze, _bfMatcher);

        _kazeDescriptors = new Mat();
        _bowDescriptors = new Mat();

        _images = new Dictionary<int, string>
        {
            { 1, "/closedLeftEyes" },
            { 2, "/openLeftEyes" },

            //{ 1, "/closedRightEyes" },
            //{ 2, "/openRightEyes" }
        };
    }

    public void Start()
    {
        // Compute KAZE descriptors

        System.Console.WriteLine("Compute KAZE descriptors...");

        foreach (var item in _images)
        {
            var files = Directory.GetFiles(
                Path + item.Value, "*.png");

            System.Console.WriteLine("Directory: " + item.Value + ". Size: " + files.Length);

            foreach (var file in files)
            {
                var image = new Image<Gray, byte>(file);

                Mat descriptors = new Mat();
                MKeyPoint[] keyPoints = _kaze.Detect(image);

                _kaze.Compute(image,
                    new VectorOfKeyPoint(keyPoints), descriptors);

                _kazeDescriptors.PushBack(descriptors);

                System.Console.WriteLine("Imagem: " + file);
            }
        }

        _bowKMeansTrainer.Add(_kazeDescriptors);

        // Cluster dictionary

        System.Console.WriteLine("Cluster dictionary...");

        Mat dictionary = new Mat();
        _bowKMeansTrainer.Cluster(dictionary);

        _bowImgDescriptorExtractor.SetVocabulary(dictionary);

        FileStorage fs = new FileStorage(
            "dictionary_left.xml", FileStorage.Mode.Write);
        fs.Write(dictionary, "dictionary");
        fs.ReleaseAndGetString();

        //FileStorage fs = new FileStorage(
        //    "dictionary_right.xml", FileStorage.Mode.Write);
        //fs.Write(dictionary, "dictionary");
        //fs.ReleaseAndGetString();

        // Compute BOW descriptors

        System.Console.WriteLine("Compute BOW descriptors...");

        var labels = new List<int>();

        foreach (var item in _images)
        {
            var files = Directory.GetFiles(
                Path + item.Value, "*.png");

            System.Console.WriteLine("Directory: " + item.Value + ". Size: " + files.Length);

            foreach (var file in files)
            {
                var image = new Image<Gray, byte>(file);

                Mat descriptors = new Mat();
                MKeyPoint[] keyPoints = _kaze.Detect(image);

                _bowImgDescriptorExtractor.Compute(image,
                    new VectorOfKeyPoint(keyPoints), descriptors);

                _bowDescriptors.PushBack(descriptors);
                labels.Add(item.Key);

                System.Console.WriteLine("Image: " + file);
            }
        }

        label = new Matrix<int>(labels.ToArray());

        // Train SVM

        System.Console.WriteLine("Train SVM...");

        SVM svm = new SVM();
        svm.SetKernel(SVM.SvmKernelType.Rbf);
        svm.Type = SVM.SvmType.CSvc;

        svm.TermCriteria = new MCvTermCriteria(400, 0.00001);

        TrainData trainData = new TrainData(
            _bowDescriptors, 
            Emgu.CV.ML.MlEnum.DataLayoutType.RowSample,
            label);

        System.Console.WriteLine("C: " + svm.C + " Gamma: " + svm.Gamma);

        bool result = svm.TrainAuto(trainData);

        System.Console.WriteLine("C: " + svm.C + " Gamma: " + svm.Gamma);

        svm.Save("svm_left.xml");
        //svm.Save("svm_right.xml");

        if (!result)
            throw new Exception("Unable to perform SVM training.");

        var test = Directory.GetFiles(
            Path + "/test", "*.png");

        foreach (var file in test)
        {
            System.Console.WriteLine("Image: " + file);

            var descriptors = new Mat();
            var image = new Image<Gray, Byte>(file);

            MKeyPoint[] keypoints = _kaze.Detect(image);
            _bowImgDescriptorExtractor.Compute(
                image,
                new VectorOfKeyPoint(keypoints),
                descriptors);

            var response = svm.Predict(descriptors);

            System.Console.WriteLine($"Result {response}");
        }
    }
}

Java

public class MainActivity
    extends Activity
    implements CameraBridgeViewBase.CvCameraViewListener2 {

static String TAG = "InMotion";

static {
    System.loadLibrary("native-lib");
}

CameraBridgeViewBase cameraView;

BaseLoaderCallback loaderCallback = new BaseLoaderCallback(this) {
    @Override
    public void onManagerConnected(int status) {
        switch (status) {
            case BaseLoaderCallback.SUCCESS:
                onLoaderCallbackSuccess();
                break;
            default:
                super.onManagerConnected(status);
        }
    }
};

/*
    Lifecycle
*/

@Override
protected void onCreate(Bundle savedInstanceState) {
    super.onCreate(savedInstanceState);
    setContentView(R.layout.activity_main);

    cameraView = findViewById(R.id.camera_view);
    cameraView.setCvCameraViewListener(this);
}

@Override
protected void onResume() {
    super.onResume();
    initOpenCV();
}

@Override
protected void onPause() {
    super.onPause();
    disableView();
}

@Override
protected void onDestroy() {
    super.onDestroy();
    disableView();
}

/*
    Contract
*/

@Override
public void onCameraViewStarted(int width, int height) {

}

@Override
public void onCameraViewStopped() {

}

@Override
public Mat onCameraFrame(CameraBridgeViewBase.CvCameraViewFrame inputFrame) {

    fit(inputFrame.gray().getNativeObjAddr());

    return inputFrame.gray();
}

/*
    Utils
*/

private void initOpenCV() {
    if (OpenCVLoader.initDebug())
        loaderCallback.onManagerConnected(LoaderCallbackInterface.SUCCESS);
    else
        OpenCVLoader.initAsync(OpenCVLoader.OPENCV_VERSION_3_4_0, this, loaderCallback);
}

private void onLoaderCallbackSuccess() {

    loadClassifier();
    loadSVM();
    loadExtractor();

    enableView();
}

private void loadClassifier() {

    File file = moveResource(
            R.raw.lbpcascade_frontalface,
            "files",
            "lbpcascade_frontalface.xml");

    if (file == null)
        return;

    loadClassifier(file.getAbsolutePath());
}

private void loadSVM() {

    File left = moveResource(R.raw.svm_left, "files", "leftSVM.xml");

    if (left == null)
        return;

    File right = moveResource(R.raw.svm_right, "files", "rightSVM.xml");

    if (right == null)
        return;

    loadSVM(left.getAbsolutePath(), right.getAbsolutePath());
}

private void loadExtractor() {

    File left = moveResource(R.raw.dictionary_left, "files", "leftDictionary.xml");

    if (left == null)
        return;

    File right = moveResource(R.raw.dictionary_right, "files", "rightDictionary.xml");

    if (right == null)
        return;

    loadExtractor(left.getAbsolutePath(), right.getAbsolutePath());
}

private void enableView() {
    if (cameraView != null) {
        cameraView.enableFpsMeter();
        cameraView.setCameraIndex(1);
        cameraView.enableView();
    }
}

private void disableView() {
    if (cameraView != null)
        cameraView.disableView();
}

private File moveResource(int resource, String directory, String fileName) {
    try {
        InputStream is = getResources().openRawResource(resource);
        File dir = getDir(directory, Context.MODE_PRIVATE);
        File file = new File(dir, fileName);
        FileOutputStream os = new FileOutputStream(file);

        byte[] buffer = new byte[8192];
        int bytesRead;
        while ((bytesRead = is.read(buffer)) != -1) {
            os.write(buffer, 0, bytesRead);
        }
        is.close();
        os.close();

        return file;

    } catch (Exception e) {
        Log.d(TAG, String.valueOf(e));
    }

    return null;
}

/*
    Native
*/

public native void loadClassifier(String classifierPath);

public native void loadSVM(String leftSVMPath, String rightSVMPath);

public native void loadExtractor(String leftDictionaryPath, String rightDictionaryPath);

public native void fit(long imageAddr); }

C++

#include <jni.h> 
#include <string> 
#include <opencv/cv.h> 
#include <opencv2/imgproc.hpp> 
#include <opencv2/features2d.hpp> 
#include <opencv2/ml/ml.hpp> 
#include <android/log.h> 
#include <opencv/cv.hpp> 

using namespace std; 
using namespace cv;
using namespace ml;

#define APPNAME "InMotion"

extern "C" {

int absoluteFaceSize = 0;
float relativeFaceSize = 0.2f;

Ptr<BFMatcher> bFMatcher;
Ptr<CascadeClassifier> classifier;

// Left

Mat leftDictionary;
Ptr<DescriptorExtractor> leftExtractor;
Ptr<BOWImgDescriptorExtractor> leftBOWImgDescriptorExtractor;
Ptr<SVM> leftSVM;

// Right

Mat rightDictionary;
Ptr<DescriptorExtractor> rightExtractor;
Ptr<BOWImgDescriptorExtractor> rightBOWImgDescriptorExtractor;
Ptr<SVM> rightSVM;

// Utils

void message(char *message, int prio = ANDROID_LOG_VERBOSE) {
    __android_log_print(prio, APPNAME, message, 1);
}

void calculateSize(int height) {

    if (absoluteFaceSize == 0)
        if (round(height * relativeFaceSize) > 0)
            absoluteFaceSize = round(height * relativeFaceSize);
}

int predictLeft(Mat mat) {

    int response = 0;
    Mat bowDescriptor;
    vector<KeyPoint> keyPoints;

    try {

        leftExtractor->detect(mat, keyPoints);
        leftBOWImgDescriptorExtractor->compute(mat, keyPoints, bowDescriptor);
        response = (int) leftSVM->predict(bowDescriptor);

    } catch (exception &e) {
        message((char *) e.what());
    }

    return response;
}

int predictRight(Mat mat) {

    int response = 0;
    Mat bowDescriptor;
    vector<KeyPoint> keyPoints;

    try {

        rightExtractor->detect(mat, keyPoints);
        rightBOWImgDescriptorExtractor->compute(mat, keyPoints, bowDescriptor);
        response = (int) rightSVM->predict(bowDescriptor);

    } catch (exception &e) {
        message((char *) e.what());
    }

    return response;
}

// Native

JNIEXPORT void JNICALL Java_br_com_caiocanalli_inmotion_MainActivity_loadClassifier(
        JNIEnv *env,
        jobject,
        jstring classifierPath) {

    const char *path = env->GetStringUTFChars(classifierPath, 0);

    classifier = makePtr<CascadeClassifier>(path);
}

JNIEXPORT void JNICALL Java_br_com_caiocanalli_inmotion_MainActivity_loadSVM(
        JNIEnv *env,
        jobject,
        jstring leftSVMPath,
        jstring rightSVMPath) {

    // Left

    const char *leftPath = env->GetStringUTFChars(leftSVMPath, 0);

    leftSVM = Algorithm::load<SVM>(leftPath);

    // Right

    const char *rightPath = env->GetStringUTFChars(rightSVMPath, 0);

    rightSVM = Algorithm::load<SVM>(rightPath);
}

JNIEXPORT void JNICALL Java_br_com_caiocanalli_inmotion_MainActivity_loadExtractor(
        JNIEnv *env,
        jobject,
        jstring leftDictionaryPath,
        jstring rightDictionaryPath) {

    bFMatcher = BFMatcher::create(NORM_L2);

    // Left

    const char *leftPath = env->GetStringUTFChars(leftDictionaryPath, 0);

    FileStorage leftFS(leftPath, FileStorage::READ);
    leftFS["dictionary"] >> leftDictionary;
    leftFS.release();

    leftExtractor = KAZE::create(true, true, 0.00001);

    leftBOWImgDescriptorExtractor = new BOWImgDescriptorExtractor(
            leftExtractor, bFMatcher);

    leftBOWImgDescriptorExtractor->setVocabulary(leftDictionary);

    // Right

    const char *rightPath = env->GetStringUTFChars(rightDictionaryPath, 0);

    FileStorage rightFS(rightPath, FileStorage::READ);
    rightFS["dictionary"] >> rightDictionary;
    rightFS.release();

    rightExtractor = KAZE::create(true, true, 0.00001);

    rightBOWImgDescriptorExtractor = new BOWImgDescriptorExtractor(
            rightExtractor, bFMatcher);

    rightBOWImgDescriptorExtractor->setVocabulary(rightDictionary);
}

JNIEXPORT void JNICALL Java_br_com_caiocanalli_inmotion_MainActivity_fit(
        JNIEnv,
        jobject,
        jlong frameAddr) {

    Mat &frame = *(Mat *) frameAddr;

    ///////////////////////////////////////////////////////////////////////////
    // Detector
    ///////////////////////////////////////////////////////////////////////////

    std::vector<cv::Rect> faces;

    calculateSize(frame.rows);

    classifier->detectMultiScale(
            frame, faces, 1.1, 2, 2,
            Size(absoluteFaceSize, absoluteFaceSize));

    if(faces.size() == 0)
        return;

    Rect face = faces[0];
    int resizeW, resizeH = 0;

    ///////////////////////////////////////////////////////////////////////////
    // Face
    ///////////////////////////////////////////////////////////////////////////

    int faceX = face.x;
    int faceY = face.y;
    int faceW = face.width;
    int faceH = face.height;

    Rect faceArea(faceX, faceY, faceW, faceH);
    rectangle(frame, faceArea, Scalar(255, 0, 0), 4, 8, 0);

    ///////////////////////////////////////////////////////////////////////////
    // LeftEye
    ///////////////////////////////////////////////////////////////////////////

    int leftX = (faceX + faceW / 2);
    int leftY = faceY + (faceH / 4.8);
    int leftW = (faceW / 2 - 60);
    int leftH = (faceH / 3);

    Rect leftEyeArea(leftX, leftY, leftW, leftH);
    rectangle(frame, leftEyeArea, Scalar(255, 0, 0), 4, 8, 0);

    Mat leftEye(frame, leftEyeArea);

    resizeW = 50;
    resizeH = (leftEye.rows * resizeW / leftEye.cols);
    cv::resize(leftEye, leftEye, Size(resizeW, resizeH));

    int leftEyeResult = predictLeft(leftEye);

    if (leftEyeResult == 1)
        message("Left closed");
    else if (leftEyeResult == 2)
        message("Left open");

    leftEye.release();

    ///////////////////////////////////////////////////////////////////////////
    // RightEye
    ///////////////////////////////////////////////////////////////////////////

    int rightX = (faceX + 60);
    int rightY = faceY + (faceH / 4.8);
    int rightW = (faceX + faceW / 2) - rightX;
    int rightH = (faceH / 3);

    Rect rightEyeArea(rightX, rightY, rightW, rightH);
    rectangle(frame, rightEyeArea, Scalar(255, 0, 0), 4, 8, 0);

    Mat rightEye(frame, rightEyeArea);

    resizeW = 50;
    resizeH = (rightEye.rows * resizeW / rightEye.cols);
    cv::resize(rightEye, rightEye, Size(resizeW, resizeH));

    int rightEyeResult = predictRight(rightEye);

    if (rightEyeResult == 1)
        message("Right closed");
    else if (rightEyeResult == 2)
        message("Left open");
}  }