How to import TensorFlow model with flatten layer?

asked 2018-01-05 06:33:01 -0600

macc.n

Hello everybody,

I have created a CNN with Keras. Here is the network definition:

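from keras.models import Sequential
from keras.layers import Conv2D, Activation, MaxPooling2D, Flatten, Dense, Dropout

model = Sequential()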

model.add(Conv2D(32, (3,3), data_format='channels_last', input_shape=(48, 32, 3), name='data'))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Flatten())
model.add(Dense(64))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(1))
model.add(Activation('sigmoid', name='result_class'))
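
The size of the Flatten output is fully determined by the fixed input shape, so it does not actually need to be computed at runtime. The minimal sketch below traces the spatial size through the model above. It assumes Keras defaults (Conv2D with padding='valid' and stride 1, MaxPooling2D with stride equal to pool_size), and the helper names are hypothetical:

```python
# Trace height/width/channels through the conv blocks of the model above.
# Assumes Keras defaults: Conv2D padding='valid', stride 1;
# MaxPooling2D stride == pool_size, padding='valid' (floor division).

def conv_valid(size, kernel):
    """Output length of a 'valid' convolution with stride 1."""
    return size - kernel + 1

def pool(size, p):
    """Output length of max pooling with pool_size == stride == p."""
    return size // p

def flatten_size(height, width, blocks):
    """blocks: list of (filters, kernel, pool) tuples, one per Conv2D+MaxPooling2D pair."""
    channels = None
    for filters, kernel, p in blocks:
        height = pool(conv_valid(height, kernel), p)
        width = pool(conv_valid(width, kernel), p)
        channels = filters  # Conv2D output depth equals its number of filters
    return height, width, channels, height * width * channels

# The three Conv2D(3x3) + MaxPooling2D(2x2) blocks, input_shape=(48, 32, 3).
blocks = [(32, 3, 2), (32, 3, 2), (64, 3, 2)]
h, w, c, n = flatten_size(48, 32, blocks)
print(h, w, c, n)  # 4 2 64 512
```

With a 48x32x3 input, Flatten therefore always emits 4 x 2 x 64 = 512 values. A static reshape to that constant (for example Keras' Reshape((512,))) is one possible replacement for the dynamic shape computation, but whether a particular OpenCV version imports it correctly would still need to be checked.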

Then, using this script, I converted the .h5 file created by Keras into a .pb file.

Now I want to import the model with OpenCV (3.3), but when I run the following code:

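#include <iostream>
#include <opencv2/dnn.hpp>

using namespace std;
using namespace cv;
using namespace cv::dnn;

// model_path holds the path to the .pb file produced by the conversion script.
Net net;
try {
    net = dnn::readNetFromTensorflow(model_path);
}
catch (cv::Exception& e) {
    cerr << "Exception: " << e.what() << endl;
}
// If the import failed, net is still empty here.
if (net.empty()) {
    cerr << "Can't load the model" << endl;
}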

I get this error:

OpenCV Error: Unspecified error (Unknown layer type Shape in op flatten_1/Shape) in populateNet, file /home/nicola/Scrivania/opencv-3.4.0/opencv-3.4.0/modules/dnn/src/tensorflow/tf_importer.cpp, line 1487
Exception: /home/nicola/Scrivania/opencv-3.4.0/opencv-3.4.0/modules/dnn/src/tensorflow/tf_importer.cpp:1487: error: (-2) Unknown layer type Shape in op flatten_1/Shape in function populateNet
Can't load the model

It seems that OpenCV can't handle the Flatten layer. Am I right? Is there a way to import my network?

Thanks for your help.
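
The failing op, flatten_1/Shape, belongs to the subgraph that Keras' Flatten layer typically exports with the TensorFlow backend. That subgraph reads the tensor shape at runtime (Shape, StridedSlice, Prod, Pack) and feeds the result to a Reshape. To see which ops the layer actually produced, you can dump the graph as text (in TF 1.x, tf.train.write_graph(graph_def, logdir, name, as_text=True) writes a .pbtxt file) and list the nodes under the flatten_1/ scope. The stdlib sketch below does this. The sample graph is hand-written and hypothetical, not taken from the model above:

```python
import re

# Hypothetical .pbtxt excerpt; node names illustrate what Keras' Flatten
# layer tends to emit, they are not copied from a real export.
SAMPLE_PBTXT = """
node {
  name: "max_pooling2d_3/MaxPool"
  op: "MaxPool"
}
node {
  name: "flatten_1/Shape"
  op: "Shape"
  input: "max_pooling2d_3/MaxPool"
}
node {
  name: "flatten_1/strided_slice"
  op: "StridedSlice"
}
node {
  name: "flatten_1/Prod"
  op: "Prod"
}
node {
  name: "flatten_1/stack"
  op: "Pack"
}
node {
  name: "flatten_1/Reshape"
  op: "Reshape"
}
"""

# Assumes the op line directly follows the name line, as in TensorFlow's text output.
NODE_RE = re.compile(r'name:\s*"([^"]+)"\s*\n\s*op:\s*"([^"]+)"')

def ops_in_scope(pbtxt, scope):
    """Return (name, op) pairs for nodes whose name starts with scope + '/'."""
    return [(name, op) for name, op in NODE_RE.findall(pbtxt)
            if name.startswith(scope + "/")]

for name, op in ops_in_scope(SAMPLE_PBTXT, "flatten_1"):
    print(f"{op:13s} {name}")
```

Running the same function on a real exported .pbtxt shows whether the Shape op named in the error is the first node of such a runtime-shape chain.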

Comments

Hi, this answer comes a bit late, but my similar problem was solved by this StackOverflow answer: Use Keras model with Flatten layer inside OpenCV 3

SEbert (2018-06-12 01:59:36 -0600)
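
Any workaround that replaces the dynamic Flatten also has to get the element order right. Keras with data_format='channels_last' flattens in height, width, channel order, while OpenCV's dnn module stores blobs channel-first (NCHW). If the NCHW blob were flattened without a permutation, the first Dense layer would receive its inputs in a different order than it was trained on. The toy sketch below compares the two orders on a 2x2x2 map whose values encode their position. It illustrates the ordering issue only and is not OpenCV's code:

```python
# Compare flattening order for the same 2x2x2 feature map.
# Each value encodes its position as 100*h + 10*w + c.

H, W, C = 2, 2, 2

def value(h, w, c):
    return 100 * h + 10 * w + c

# Keras (channels_last) flattens in (h, w, c) order.
nhwc_flat = [value(h, w, c) for h in range(H) for w in range(W) for c in range(C)]

# A channel-first blob flattened as-is comes out in (c, h, w) order.
nchw_flat = [value(h, w, c) for c in range(C) for h in range(H) for w in range(W)]

print(nhwc_flat)  # [0, 1, 10, 11, 100, 101, 110, 111]
print(nchw_flat)  # [0, 10, 100, 110, 1, 11, 101, 111]
```

Both lists hold the same values in different orders. Comparing OpenCV's output with Keras' model.predict() on the same input is a cheap way to confirm that an imported model is wired correctly.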

1 answer

answered 2019-01-06 20:07:53 -0600

nvnnghia

Check this repo: https://github.com/nvnnghia/opencv-Im... . It uses VGG16 for image classification on CIFAR10 and runs inference with OpenCV using a TensorFlow graph.

Stats

Seen: 1,479 times

Last updated: Jan 06 '19