
OpenCV Error: Assertion failed (!roots.empty()) in function predict

asked 2018-11-15 09:10:54 -0600

paddy74

I am attempting to use OpenCV's RTrees to create a random forest classifier. However, I am not able to successfully load a model saved using the save function. An example of the code I used is as follows:

import numpy as np
import cv2

def train(samples, class_labels, save_file=None):
    """
    samples      : np.ndarray of type np.float32
    class_labels : np.ndarray of type np.int32
    save_file    : str of the absolute path to the save file
    """
    model = cv2.ml.RTrees_create()

    # Set parameters
    model.setMaxDepth(20)
    model.setActiveVarCount(0)
    term_type, n_trees, epsilon = cv2.TERM_CRITERIA_MAX_ITER, 128, 1
    model.setTermCriteria((term_type, n_trees, epsilon))

    train_data = cv2.ml.TrainData_create(samples=samples,
                                         layout=cv2.ml.ROW_SAMPLE,
                                         responses=class_labels)
    model.train(trainData=train_data)

    if save_file:
        model.save(save_file)

def test(samples, load_file):
    """
    samples   : np.ndarray of type np.float32
    load_file : str of the absolute path to the load file
    """
    model = cv2.ml.RTrees_create()
    model.load(load_file)

    _ret, responses = model.predict(samples)
    return responses.ravel()

if __name__ == '__main__':
    model_file = '/home/user/bin/tmp/m.xml'

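    # OpenCV's ml module expects float32 samples and int32 class labels
    x = np.random.randint(0, 5, (1000, 15)).astype(np.float32)
    y = np.random.randint(0, 4, 1000).astype(np.int32)
    train(x, y, model_file)

    x_test = np.random.randint(0, 5, (100, 15)).astype(np.float32)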
    test(x_test, model_file)

The error is as follows:

Traceback (most recent call last):
  File "tmp.py", line 45, in <module>
    test(x_test, model_file)
  File "tmp.py", line 34, in test
    _ret, responses = model.predict(samples)
cv2.error: OpenCV(3.4.2) /io/opencv/modules/ml/src/tree.cpp:1498: error: (-215:Assertion failed) !roots.empty() in function 'predict'

1 answer


answered 2018-11-15 09:16:24 -0600

berak

Easy answer: the load() function returns a new object. It does not load into the object you call it on, so the old one is still empty! (The same goes for all the other ml classes, btw.) See for yourself:

>>> help(cv2.ml.RTrees_load)
Help on built-in function RTrees_load:

RTrees_load(...)
    RTrees_load(filepath[, nodeName]) -> retval
    .   @brief Loads and creates a serialized RTree from a file
    .   *
    .   * Use RTree::save to serialize and store an RTree to disk.
    .   * Load the RTree from this file again, by calling this function with the path to the file.
    .   * Optionally specify the node for the file containing the classifier
    .   *
    .   * @param filepath path to serialized RTree
    .   * @param nodeName name of node containing the classifier

So you have to use:

    model = cv2.ml.RTrees_load(load_file)

NOT:

    model = cv2.ml.RTrees_create()
    model.load(load_file)

Stats

Asked: 2018-11-15 09:10:54 -0600


Last updated: Nov 15 '18