mxnet模型转tensorflow模型
requirements:pip install mxnet "tensorflow<2"(下文脚本使用 tf.Session 和 tensorflow.contrib,需要 TensorFlow 1.x)
1.安装mmdnn
pip install -U git+https://github.com/Microsoft/MMdnn.git@master
2.模型下载
https://pan.baidu.com/s/1If28BkHde4fiuweJrbicVA
3.mxnet模型转IR
python -m mmdnn.conversion._script.convertToIR -f mxnet -n model-symbol.json -w model-0000.params -d tf_model --inputShape 3,112,112
#info:
IR network structure is saved as [tf_model.json].#可视化文件
IR network structure is saved as [tf_model.pb].#网络结构
IR weights are saved as [tf_model.npy].#权重参数
若报错:AttributeError: 'NoneType' object has no attribute 'asnumpy'(mxnet_parser.py 第 410 行)
解决:找到python3.6/site-packages/mmdnn/conversion/mxnet/mxnet_parser.py 410 行 修改如下:
weight = self.weight_data.get("fc1_weight").asnumpy().transpose((1, 0))
参考:https://github.com/microsoft/MMdnn/issues/231
4.生成tf_model.py 用于还原神经网络结构
调用tf_model.py中的KitModel函数加载npy权重参数可重新生成原网络框架
python -m mmdnn.conversion._script.IRToCode -f tensorflow --IRModelPath tf_model.pb --IRWeightPath tf_model.npy --dstModelPath tf_model.py
#info:
Parse file [tf_model.pb] with binary format successfully.
Target network code snippet is saved as [tf_model.py].
5.验证模型输出结果是否一致
5.1 test_mxnet.py
import mxnet as mx
from tensorflow.contrib.keras.api.keras.preprocessing import image
import numpy as np
from collections import namedtuple
# Batch wrapper handed to Module.forward(); it carries a single `data` field.
Batch = namedtuple('Batch', ['data'])

# Load the checkpoint with prefix 'mobile/model' and epoch 0, i.e. the
# 'mobile' folder holds model-symbol.json and model-0000.params.
symbol, arg_params, aux_params = mx.model.load_checkpoint('mobile/model', 0)

# Bind an inference-only module on CPU for one 3x112x112 input.
device = mx.cpu(0)
module = mx.mod.Module(symbol=symbol, context=device, label_names=None)
module.bind(
    for_training=False,
    data_shapes=[('data', (1, 3, 112, 112))],
    label_shapes=module._label_shapes,
)
module.set_params(arg_params, aux_params, allow_missing=True)

# Preprocess: load at 112x112, reverse the last (channel) axis, add a batch
# axis, then transpose to match the (1, 3, 112, 112) shape bound above.
image_path = 'face.jpeg'
pixels = image.img_to_array(image.load_img(image_path, target_size=(112, 112)))
batch = np.expand_dims(pixels[..., ::-1], 0).transpose((0, 3, 1, 2))

# Run the forward pass and print the first output with size-1 axes removed.
module.forward(Batch([mx.nd.array(batch)]))
output = np.squeeze(module.get_outputs()[0].asnumpy())
print(output)
执行命令
python test_mxnet.py
#info:
[17:39:57] src/nnvm/legacy_json_util.cc:209: Loading symbol saved by previous version v1.0.0. Attempting to upgrade...
[17:39:57] src/nnvm/legacy_json_util.cc:217: Symbol successfully upgraded!
[ 0.344491 0.10190611 -0.24501216 0.6819046 0.88096315 0.347766
-0.94702303 -0.67586213 -0.43900824 0.81431276 -0.4899036 -0.43025514
-0.50644076 -0.27366892 0.63601595 -0.5352368 0.13765731 0.40842316
0.76525426 -0.8959755 0.42129532 -0.38290668 0.02023177 -0.14840017
0.9108279 -0.27738237 -0.6017331 -0.214954 0.37644073 0.48894417
-0.8824417 0.31846505 0.19936565 0.27296835 1.5621403 0.4327985
-0.6486908 -0.23494942 -0.8708738 -0.77051663 0.09255238 -1.1803752
-0.17184262 0.2543226 -0.19088541 -0.26873437 0.9160875 -0.18985008
-0.4793183 -0.32987356 -1.3085973 1.2959319 -0.00581244 0.12396478
1.2034996 0.0991946 -1.9225345 0.92873436 -0.285992 0.11249313
-1.4562801 1.8767762 -1.2222489 -0.03905598 0.5152731 0.04876914
0.04671988 -0.32384786 -0.88341314 0.58193505 -0.7378911 -0.3082042
0.22141728 0.7255646 0.24394599 0.6563271 -0.46760473 -0.38698462
-0.11467619 -0.9940818 -1.1298056 1.015201 0.03592067 0.6738027
-0.5814839 0.1565634 -0.06737513 -1.040216 -0.9286871 -0.11091176
-0.66596293 0.03736701 -0.35337982 -0.4175317 -0.47258058 -0.62383175
-0.86612004 -0.5230916 -1.7838901 0.08661752 -0.02590845 0.23406455
0.77719927 1.4410776 0.41925532 0.4560187 -0.02141571 0.7005563
-0.58727044 -0.39757103 1.2808248 -1.1874324 -0.27268586 -0.82443166
0.39704558 -1.2778002 -0.52762616 -0.26455742 1.2137026 0.04997367
0.05591454 1.0264031 1.5093948 -0.5634581 -1.1715719 -0.646347
0.6021179 0.6725963 ]
5.2 test_tensorflow.py
from __future__ import absolute_import
import argparse
import numpy as np
from six import text_type as _text_type
from tensorflow.contrib.keras.api.keras.preprocessing import image
import tensorflow as tf
# Command-line options: the generated network module, the weights file and
# the test image.
parser = argparse.ArgumentParser()
parser.add_argument('-n', type=_text_type, default='kitModel',
                    help='Network structure file name.')
parser.add_argument('-w', type=_text_type, required=True,
                    help='Network weights file name')
parser.add_argument('--image', '-i',
                    type=_text_type, help='Test image path.',
                    default="face.jpeg")
args = parser.parse_args()

# __import__ takes a module name, so strip a trailing ".py" from the file name.
if args.n.endswith('.py'):
    args.n = args.n[:-3]

# KitModel(weights_path) is unpacked as (input tensor, output tensor).
model_converted = __import__(args.n).KitModel(args.w)
input_tf, model_tf = model_converted

# Same preprocessing as test_mxnet.py, except the batch is left in
# height-width-channel order (no transpose).
img = image.load_img(args.image, target_size=(112, 112))
img = image.img_to_array(img)
img = img[..., ::-1]
input_data = np.expand_dims(img, 0)

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    predict = sess.run(model_tf, feed_dict={input_tf: input_data})
    print(predict)
执行命令
python test_tensorflow.py -n tf_model.py -w tf_model.npy -i face.jpeg
#info:
2019-10-25 17:51:37.745502: I tensorflow/core/common_runtime/process_util.cc:71] Creating new thread pool with default inter op setting: 12. Tune using inter_op_parallelism_threads for best performance.
[[ 0.3444912 0.10190725 -0.24501228 0.6819044 0.88096356 0.3477651
-0.9470245 -0.67586106 -0.43900767 0.8143126 -0.48990446 -0.43025535
-0.50643945 -0.27366814 0.63601726 -0.5352377 0.13765681 0.40842274
0.7652553 -0.8959763 0.42129317 -0.38290572 0.02023016 -0.14840023
0.91082776 -0.27738187 -0.60173315 -0.2149537 0.37644142 0.48894492
-0.8824413 0.3184655 0.19936629 0.2729676 1.5621389 0.4327973
-0.6486915 -0.23494866 -0.87087345 -0.77051604 0.09255352 -1.180374
-0.17184272 0.25432315 -0.19088425 -0.26873374 0.91608876 -0.18985137
-0.4793172 -0.3298719 -1.308598 1.2959337 -0.00581198 0.12396422
1.2034999 0.09919477 -1.9225347 0.92873377 -0.28599226 0.11249284
-1.4562793 1.876776 -1.2222495 -0.03905648 0.5152732 0.04876836
0.04672025 -0.32384863 -0.8834132 0.581934 -0.7378913 -0.30820462
0.22141635 0.72556514 0.2439455 0.6563256 -0.46760577 -0.38698506
-0.1146768 -0.9940842 -1.1298054 1.015199 0.03592021 0.67380327
-0.58148336 0.15656358 -0.06737413 -1.0402167 -0.9286856 -0.11091161
-0.66596127 0.03736706 -0.35337985 -0.41753066 -0.47258082 -0.62383235
-0.8661205 -0.52309173 -1.7838898 0.08661895 -0.02590791 0.23406385
0.7771991 1.4410769 0.41925538 0.45601875 -0.02141583 0.70055544
-0.587271 -0.3975702 1.2808259 -1.1874334 -0.27268624 -0.8244319
0.39704552 -1.2778007 -0.5276267 -0.2645575 1.2137012 0.04997464
0.05591418 1.0264043 1.5093954 -0.5634565 -1.1715722 -0.6463482
0.60211945 0.67259526]]
6.基于tf_model.npy和tf_model.py文件,固化参数,生成PB文件
freeze_graph.py
import tensorflow as tf
import tf_model as tf_fun
def netWork():
    """Build the converted network with weights loaded from ./tf_model.npy.

    Returns:
        Whatever ``tf_model.KitModel`` returns; ``freeze_graph`` unpacks it
        as a pair ``(data, fc1)``.
    """
    # KitModel is defined in the generated tf_model.py and loads the npy weights.
    model = tf_fun.KitModel("./tf_model.npy")
    return model
def freeze_graph(output_graph):
    """Freeze the converted network's variables and save it as one .pb file.

    Args:
        output_graph: Path of the frozen GraphDef file to write.
    """
    output_node_names = "output"
    data, fc1 = netWork()
    # Give the network output a fixed node name ("output") to export.
    fc1 = tf.identity(fc1, name="output")

    graph = tf.get_default_graph()          # the default graph
    input_graph_def = graph.as_graph_def()  # serialized form of that graph
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        # Persist the model by turning variable values into constants.
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess=sess,
            input_graph_def=input_graph_def,  # same as sess.graph_def
            # Several output nodes would be given comma-separated.
            output_node_names=output_node_names.split(","))
        with tf.gfile.GFile(output_graph, "wb") as f:  # save the model
            f.write(output_graph_def.SerializeToString())
# Script entry point: write the frozen graph next to the script.
if __name__ == '__main__':
    freeze_graph("frozen_model.pb")
    print("finish!")
python freeze_graph.py
#info:
Instructions for updating:
Use tf.compat.v1.graph_util.extract_sub_graph
finish!
测试pb模型
test_pb.py
from tensorflow.contrib.keras.api.keras.preprocessing import image
import tensorflow as tf
import numpy as np
# Load the frozen GraphDef and import it into the default graph. name=''
# keeps the original node names, so 'data:0' and 'output:0' resolve below.
with tf.gfile.GFile('frozen_model.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

with tf.Session() as sess:
    # Same preprocessing as test_tensorflow.py.
    img = image.load_img('face.jpeg', target_size=(112, 112))
    img = image.img_to_array(img)
    img = img[..., ::-1]
    input_data = np.expand_dims(img, 0)

    init = tf.global_variables_initializer()
    sess.run(init)

    # print(tf.get_default_graph().get_operations()) lists the operations:
    # the input tensor is named "data" and the fully connected layer's
    # output tensor is named "output". The pre-identity tensor can be
    # fetched instead with:
    #   op = sess.graph.get_tensor_by_name('fc1/add_1:0')
    #   predict = sess.run(op, feed_dict={'data:0': input_data})
    op = sess.graph.get_tensor_by_name('output:0')
    predict = sess.run(op, feed_dict={'data:0': input_data})
    print(predict)
python test_pb.py
#info:
[[ 0.3444914 0.10190733 -0.2450121 0.68190414 0.8809633 0.3477656
-0.94702375 -0.6758606 -0.43900838 0.8143138 -0.48990518 -0.43025577
-0.5064386 -0.2736677 0.6360168 -0.5352382 0.13765849 0.40842175
0.7652537 -0.8959745 0.42129484 -0.3829043 0.02023116 -0.14839967
0.9108265 -0.27738202 -0.6017342 -0.21495399 0.37644026 0.48894358
-0.88244045 0.31846407 0.19936593 0.2729677 1.5621401 0.4327974
-0.6486902 -0.23494998 -0.8708729 -0.7705149 0.09255301 -1.1803752
-0.17184293 0.25432175 -0.19088468 -0.2687335 0.91608775 -0.18984997
-0.4793176 -0.32987317 -1.3085989 1.295933 -0.00581315 0.12396483
1.2034997 0.09919491 -1.9225343 0.9287349 -0.28599253 0.11249258
-1.4562799 1.876776 -1.222248 -0.03905599 0.5152738 0.04876841
0.04672143 -0.32384768 -0.88341135 0.58193433 -0.73789096 -0.30820417
0.22141536 0.7255656 0.24394391 0.65632653 -0.4676048 -0.38698527
-0.11467646 -0.9940818 -1.1298064 1.0152006 0.03592108 0.6738041
-0.5814836 0.1565624 -0.06737386 -1.0402162 -0.92868716 -0.1109117
-0.66596234 0.03736827 -0.35337996 -0.41753033 -0.4725821 -0.623832
-0.86612093 -0.5230911 -1.7838906 0.0866183 -0.02590806 0.23406453
0.77719873 1.4410781 0.41925526 0.45601833 -0.02141543 0.7005538
-0.5872698 -0.39757127 1.2808269 -1.1874311 -0.2726869 -0.8244321
0.39704597 -1.2778006 -0.5276277 -0.26455772 1.2137022 0.04997511
0.05591565 1.0264047 1.5093955 -0.56345737 -1.1715721 -0.64634734
0.60211873 0.67259526]]
7. .pb模型转换为.tflite模型供tensorflow lite调用(移动端部署)
pb2tflite.py
import tensorflow as tf
# Build a converter from the frozen graph; the input/output names match the
# "data" and "output" nodes of frozen_model.pb.
convert = tf.lite.TFLiteConverter.from_frozen_graph(
    "frozen_model.pb",
    input_arrays=["data"],
    output_arrays=["output"],
    input_shapes={"data": [1, 112, 112, 3]})
convert.post_training_quantize = True  # whether to quantize
tflite_model = convert.convert()

# Write through a context manager so the file handle is flushed and closed.
with open("mobilenet.tflite", "wb") as f:
    f.write(tflite_model)
print("finish!")
python pb2tflite.py
info: 文件夹中生成mobilenet.tflite
本博客所有文章除特别声明外,均采用 CC BY-SA 3.0协议 。转载请注明出处!