1 | initial version |
@MennoK, You may see that TensorFlow makes a flattening through Reshape
layer using dynamically calculated shape:
import tensorflow as tf
# Read the graph.
with tf.gfile.FastGFile('model.pb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
with tf.Session() as sess:
# Restore session
sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
tf.summary.FileWriter('logs', graph_def)
OpenCV can't do it. There are two possible ways: replace a flatten node by reshape during graph creation or help OpenCV to import it using text graph.
Run optimize_for_inference.py
tool to remove some extra nodes from the graph:
python ~/tensorflow/tensorflow/python/tools/optimize_for_inference.py \
--input model.pb \
--output opt_model.pb \
--input_names input_1 \
--output_names predictions/Softmax
Run the following script to make a text graph representation:
import tensorflow as tf
# Read the graph.
with tf.gfile.FastGFile('model.pb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# Remove Const nodes.
for i in reversed(range(len(graph_def.node))):
if graph_def.node[i].op == 'Const':
del graph_def.node[i]
for attr in ['T', 'data_format', 'Tshape', 'N', 'Tidx', 'Tdim',
'use_cudnn_on_gpu', 'Index', 'Tperm', 'is_training',
'Tpaddings']:
if attr in graph_def.node[i].attr:
del graph_def.node[i].attr[attr]
# Save as text.
tf.train.write_graph(graph_def, "", "model.pbtxt", as_text=True)
Open model.pbtxt
and remove nodes with names flatten/Shape
, flatten/strided_slice
, flatten/Prod
, flatten/stack
. Replace a node
node {
name: "flatten/Reshape"
op: "Reshape"
input: "block5_pool/MaxPool"
input: "flatten/stack"
}
on
node {
name: "flatten/Reshape"
op: "Flatten"
input: "block5_pool/MaxPool"
}
Now you can remove open_graph.pb
and use both model.pb
and model.pbtxt
in OpenCV:
import cv2 as cv
import numpy as np
net = cv.dnn.readNetFromTensorflow('model.pb', 'model.pbtxt')
inp = np.random.standard_normal([1, 3, 224, 224]).astype(np.float32)
net.setInput(inp)
out = net.forward()
2 | No.2 Revision |
@MennoK, You may see that TensorFlow makes a flattening through Reshape
layer using dynamically calculated shape:
import tensorflow as tf
# Read the graph.
with tf.gfile.FastGFile('model.pb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
with tf.Session() as sess:
# Restore session
sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
tf.summary.FileWriter('logs', graph_def)
OpenCV can't do it. There are two possible ways: replace a flatten node by reshape during graph creation or help OpenCV to import it using text graph.
Run optimize_for_inference.py
tool to remove some extra nodes from the graph:
python ~/tensorflow/tensorflow/python/tools/optimize_for_inference.py \
--input model.pb \
--output opt_model.pb \
--input_names input_1 \
--output_names predictions/Softmax
Run the following script to make a text graph representation:
import tensorflow as tf
# Read the graph.
with tf.gfile.FastGFile('model.pb') tf.gfile.FastGFile('opt_model.pb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# Remove Const nodes.
for i in reversed(range(len(graph_def.node))):
if graph_def.node[i].op == 'Const':
del graph_def.node[i]
for attr in ['T', 'data_format', 'Tshape', 'N', 'Tidx', 'Tdim',
'use_cudnn_on_gpu', 'Index', 'Tperm', 'is_training',
'Tpaddings']:
if attr in graph_def.node[i].attr:
del graph_def.node[i].attr[attr]
# Save as text.
tf.train.write_graph(graph_def, "", "model.pbtxt", as_text=True)
Open model.pbtxt
and remove nodes with names flatten/Shape
, flatten/strided_slice
, flatten/Prod
, flatten/stack
. Replace a node
node {
name: "flatten/Reshape"
op: "Reshape"
input: "block5_pool/MaxPool"
input: "flatten/stack"
}
on
node {
name: "flatten/Reshape"
op: "Flatten"
input: "block5_pool/MaxPool"
}
Now you can remove
and use both open_graph.pbopt_model.pbmodel.pb
and model.pbtxt
in OpenCV:
import cv2 as cv
import numpy as np
net = cv.dnn.readNetFromTensorflow('model.pb', 'model.pbtxt')
inp = np.random.standard_normal([1, 3, 224, 224]).astype(np.float32)
net.setInput(inp)
out = net.forward()