Converting an MXNet model to a TensorFlow model

Requirements: pip install mxnet tensorflow (the scripts below use tf.Session and tf.contrib, so TensorFlow 1.x is needed)

1. Install MMdnn

pip install -U git+https://github.com/Microsoft/MMdnn.git@master

2. Download the model

https://pan.baidu.com/s/1If28BkHde4fiuweJrbicVA

3. Convert the MXNet model to IR

    python -m mmdnn.conversion._script.convertToIR -f mxnet -n model-symbol.json -w model-0000.params -d tf_model --inputShape 3,112,112
    # info:
    IR network structure is saved as [tf_model.json].  # visualization file
    IR network structure is saved as [tf_model.pb].    # network structure
    IR weights are saved as [tf_model.npy].            # weights

If the conversion fails with: AttributeError: 'NoneType' object has no attribute 'asnumpy' (line 410)

Fix: open python3.6/site-packages/mmdnn/conversion/mxnet/mxnet_parser.py and change line 410 to:

    weight = self.weight_data.get("fc1_weight").asnumpy().transpose((1, 0))

Reference: https://github.com/microsoft/MMdnn/issues/231
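
The error and the fix are easier to see in isolation. The sketch below uses a plain dict of NumPy arrays as a stand-in for the parser's weight_data (the real object holds MXNet NDArrays, hence the .asnumpy() call); the array sizes and the wrong key name are made up for illustration. It shows that dict.get returns None for a missing key, and that the transpose turns MXNet's (num_hidden, input_dim) fully connected weight layout into (input_dim, num_hidden).

```python
import numpy as np

# Stand-in for the parser's weight dictionary. The real one maps parameter
# names to MXNet NDArrays; plain NumPy arrays are used here instead.
# The sizes (128 outputs, 512 inputs) are illustrative, not the model's.
weight_data = {"fc1_weight": np.zeros((128, 512), dtype=np.float32)}

# Looking up a name that is not in the dict returns None, and calling a
# method on None raises the AttributeError quoted above.
missing = weight_data.get("hypothetical_wrong_name")
try:
    missing.asnumpy()
except AttributeError as err:
    error_text = str(err)

# With the right key the lookup succeeds. MXNet stores fully connected
# weights as (num_hidden, input_dim); the transpose swaps the two axes.
weight = weight_data.get("fc1_weight").transpose((1, 0))
print(missing, weight.shape)
```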

4. Generate tf_model.py, which rebuilds the network structure

Calling the KitModel function in tf_model.py with the npy weight file rebuilds the original network.

    python -m mmdnn.conversion._script.IRToCode -f tensorflow --IRModelPath tf_model.pb --IRWeightPath tf_model.npy --dstModelPath tf_model.py
    # info:
    Parse file [tf_model.pb] with binary format successfully.
    Target network code snippet is saved as [tf_model.py].
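
Before calling KitModel it can help to look at what the weight file contains. The sketch below assumes the .npy file is a pickled Python dict that maps layer names to dicts of arrays, which is how MMdnn-generated loaders appear to read it; summarize_weights is a hypothetical helper, and the demo file with its layer and parameter names is invented so the snippet runs without the real model.

```python
import os
import tempfile
import numpy as np


def summarize_weights(npy_path):
    """Return {layer_name: {param_name: shape}} for a pickled-dict .npy file."""
    # allow_pickle=True is needed because the file stores a Python dict,
    # not a plain array; only load files you trust.
    weights = np.load(npy_path, allow_pickle=True).item()
    return {
        layer: {name: np.asarray(value).shape for name, value in params.items()}
        for layer, params in weights.items()
    }


# Build a tiny stand-in file so the sketch runs without the real model.
# The layer and parameter names here are hypothetical.
demo = {"fc1": {"weights": np.zeros((512, 128)), "bias": np.zeros(128)}}
demo_path = os.path.join(tempfile.mkdtemp(), "demo_weights.npy")
np.save(demo_path, demo, allow_pickle=True)

summary = summarize_weights(demo_path)
print(summary)
```

For the converted model, the call would be summarize_weights("tf_model.npy").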
5. Verify that the model outputs match

5.1 test_mxnet.py

    import mxnet as mx
    from tensorflow.contrib.keras.api.keras.preprocessing import image
    import numpy as np
    from collections import namedtuple

    Batch = namedtuple('Batch', ['data'])
    ctx = mx.cpu(0)

    # Load the model; the mobile/ folder holds model-symbol.json and model-0000.params
    sym, arg_params, aux_params = mx.model.load_checkpoint('mobile/model', 0)
    mod = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
    mod.bind(for_training=False, data_shapes=[('data', (1, 3, 112, 112))], label_shapes=mod._label_shapes)
    mod.set_params(arg_params, aux_params, allow_missing=True)

    path = 'face.jpeg'
    img = image.load_img(path, target_size=(112, 112))
    img = image.img_to_array(img)
    img = img[..., ::-1]  # reverse the channel axis (RGB -> BGR)
    img = np.expand_dims(img, 0).transpose((0, 3, 1, 2))  # NHWC -> NCHW
    mod.forward(Batch([mx.nd.array(img)]))
    prob = mod.get_outputs()[0].asnumpy()
    prob = np.squeeze(prob)
    print(prob)

Run:

    python test_mxnet.py
    # info:
    [17:39:57] src/nnvm/legacy_json_util.cc:209: Loading symbol saved by previous version v1.0.0. Attempting to upgrade...
    [17:39:57] src/nnvm/legacy_json_util.cc:217: Symbol successfully upgraded!
    [ 0.344491 0.10190611 -0.24501216 0.6819046 0.88096315 0.347766
    -0.94702303 -0.67586213 -0.43900824 0.81431276 -0.4899036 -0.43025514
    -0.50644076 -0.27366892 0.63601595 -0.5352368 0.13765731 0.40842316
    0.76525426 -0.8959755 0.42129532 -0.38290668 0.02023177 -0.14840017
    0.9108279 -0.27738237 -0.6017331 -0.214954 0.37644073 0.48894417
    -0.8824417 0.31846505 0.19936565 0.27296835 1.5621403 0.4327985
    -0.6486908 -0.23494942 -0.8708738 -0.77051663 0.09255238 -1.1803752
    -0.17184262 0.2543226 -0.19088541 -0.26873437 0.9160875 -0.18985008
    -0.4793183 -0.32987356 -1.3085973 1.2959319 -0.00581244 0.12396478
    1.2034996 0.0991946 -1.9225345 0.92873436 -0.285992 0.11249313
    -1.4562801 1.8767762 -1.2222489 -0.03905598 0.5152731 0.04876914
    0.04671988 -0.32384786 -0.88341314 0.58193505 -0.7378911 -0.3082042
    0.22141728 0.7255646 0.24394599 0.6563271 -0.46760473 -0.38698462
    -0.11467619 -0.9940818 -1.1298056 1.015201 0.03592067 0.6738027
    -0.5814839 0.1565634 -0.06737513 -1.040216 -0.9286871 -0.11091176
    -0.66596293 0.03736701 -0.35337982 -0.4175317 -0.47258058 -0.62383175
    -0.86612004 -0.5230916 -1.7838901 0.08661752 -0.02590845 0.23406455
    0.77719927 1.4410776 0.41925532 0.4560187 -0.02141571 0.7005563
    -0.58727044 -0.39757103 1.2808248 -1.1874324 -0.27268586 -0.82443166
    0.39704558 -1.2778002 -0.52762616 -0.26455742 1.2137026 0.04997367
    0.05591454 1.0264031 1.5093948 -0.5634581 -1.1715719 -0.646347
    0.6021179 0.6725963 ]
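
test_mxnet.py feeds the image in NCHW layout (1, 3, 112, 112), while the TensorFlow scripts feed NHWC (1, 112, 112, 3). The sketch below replays the same preprocessing steps on a tiny 2x2 stand-in array so the effect of each step is visible; the array contents are arbitrary.

```python
import numpy as np

# Stand-in for image.img_to_array(...): height x width x channels, RGB order.
# A tiny 2x2 image keeps the values easy to follow.
img = np.arange(2 * 2 * 3, dtype=np.float32).reshape(2, 2, 3)

bgr = img[..., ::-1]                 # reverse the channel axis: RGB -> BGR
nhwc = np.expand_dims(bgr, 0)        # add a batch axis: (1, H, W, C)
nchw = nhwc.transpose((0, 3, 1, 2))  # move channels forward: (1, C, H, W)

print(nhwc.shape, nchw.shape)
# The same pixel value is addressed differently in the two layouts.
print(nhwc[0, 1, 0, 2] == nchw[0, 2, 1, 0])
```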

5.2 test_tensorflow.py

    from __future__ import absolute_import
    import argparse
    import numpy as np
    from six import text_type as _text_type
    from tensorflow.contrib.keras.api.keras.preprocessing import image
    import tensorflow as tf

    parser = argparse.ArgumentParser()
    parser.add_argument('-n', type=_text_type, default='kitModel',
                        help='Network structure file name.')
    parser.add_argument('-w', type=_text_type, required=True,
                        help='Network weights file name')
    parser.add_argument('--image', '-i',
                        type=_text_type, help='Test image path.',
                        default="face.jpeg")
    args = parser.parse_args()

    # Strip the .py suffix so the file can be imported as a module
    if args.n.endswith('.py'):
        args.n = args.n[:-3]

    # KitModel returns the input placeholder and the output tensor
    model_converted = __import__(args.n).KitModel(args.w)
    input_tf, model_tf = model_converted

    img = image.load_img(args.image, target_size=(112, 112))
    img = image.img_to_array(img)
    img = img[..., ::-1]  # reverse the channel axis (RGB -> BGR)
    input_data = np.expand_dims(img, 0)  # NHWC batch of one

    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
        predict = sess.run(model_tf, feed_dict={input_tf: input_data})
        print(predict)

Run:

    python test_tensorflow.py -n tf_model.py -w tf_model.npy -i face.jpeg
    # info:
    2019-10-25 17:51:37.745502: I tensorflow/core/common_runtime/process_util.cc:71] Creating new thread pool with default inter op setting: 12. Tune using inter_op_parallelism_threads for best performance.
    [[ 0.3444912 0.10190725 -0.24501228 0.6819044 0.88096356 0.3477651
    -0.9470245 -0.67586106 -0.43900767 0.8143126 -0.48990446 -0.43025535
    -0.50643945 -0.27366814 0.63601726 -0.5352377 0.13765681 0.40842274
    0.7652553 -0.8959763 0.42129317 -0.38290572 0.02023016 -0.14840023
    0.91082776 -0.27738187 -0.60173315 -0.2149537 0.37644142 0.48894492
    -0.8824413 0.3184655 0.19936629 0.2729676 1.5621389 0.4327973
    -0.6486915 -0.23494866 -0.87087345 -0.77051604 0.09255352 -1.180374
    -0.17184272 0.25432315 -0.19088425 -0.26873374 0.91608876 -0.18985137
    -0.4793172 -0.3298719 -1.308598 1.2959337 -0.00581198 0.12396422
    1.2034999 0.09919477 -1.9225347 0.92873377 -0.28599226 0.11249284
    -1.4562793 1.876776 -1.2222495 -0.03905648 0.5152732 0.04876836
    0.04672025 -0.32384863 -0.8834132 0.581934 -0.7378913 -0.30820462
    0.22141635 0.72556514 0.2439455 0.6563256 -0.46760577 -0.38698506
    -0.1146768 -0.9940842 -1.1298054 1.015199 0.03592021 0.67380327
    -0.58148336 0.15656358 -0.06737413 -1.0402167 -0.9286856 -0.11091161
    -0.66596127 0.03736706 -0.35337985 -0.41753066 -0.47258082 -0.62383235
    -0.8661205 -0.52309173 -1.7838898 0.08661895 -0.02590791 0.23406385
    0.7771991 1.4410769 0.41925538 0.45601875 -0.02141583 0.70055544
    -0.587271 -0.3975702 1.2808259 -1.1874334 -0.27268624 -0.8244319
    0.39704552 -1.2778007 -0.5276267 -0.2645575 1.2137012 0.04997464
    0.05591418 1.0264043 1.5093954 -0.5634565 -1.1715722 -0.6463482
    0.60211945 0.67259526]]
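
Comparing two printed vectors by eye is error-prone, so a numeric check is a useful complement. The sketch below uses only the first four values of the MXNet and TensorFlow outputs printed above; the tolerance of 1e-5 is an arbitrary choice for this example, not a value from the original workflow.

```python
import numpy as np

# First four values of each printed vector, copied from the outputs above.
mxnet_out = np.array([0.344491, 0.10190611, -0.24501216, 0.6819046])
tf_out = np.array([0.3444912, 0.10190725, -0.24501228, 0.6819044])

# Largest element-wise difference and cosine similarity of the two vectors.
max_abs_diff = np.max(np.abs(mxnet_out - tf_out))
cosine = np.dot(mxnet_out, tf_out) / (
    np.linalg.norm(mxnet_out) * np.linalg.norm(tf_out)
)

# atol=1e-5 is an arbitrary tolerance chosen for this sketch.
print(np.allclose(mxnet_out, tf_out, atol=1e-5), max_abs_diff < 1e-5)
```

To compare the full vectors, each test script could save its result with np.save and a third script could load both files and run the same check.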
6. Freeze the parameters from tf_model.npy and tf_model.py into a PB file

freeze_graph.py

    import tensorflow as tf
    import tf_model as tf_fun


    def netWork():
        # Call KitModel in tf_model.py to load the npy weights
        model = tf_fun.KitModel("./tf_model.npy")
        return model


    def freeze_graph(output_graph):
        output_node_names = "output"
        data, fc1 = netWork()
        fc1 = tf.identity(fc1, name="output")
        graph = tf.get_default_graph()  # get the default graph
        input_graph_def = graph.as_graph_def()  # serialized GraphDef of the current graph
        init = tf.global_variables_initializer()
        with tf.Session() as sess:
            sess.run(init)
            # Persist the model by converting variables to constants
            output_graph_def = tf.graph_util.convert_variables_to_constants(
                sess=sess,
                input_graph_def=input_graph_def,  # same as sess.graph_def
                output_node_names=output_node_names.split(","))  # comma-separated if there are several output nodes
            with tf.gfile.GFile(output_graph, "wb") as f:  # save the model
                f.write(output_graph_def.SerializeToString())  # write the serialized graph


    if __name__ == '__main__':
        freeze_graph("frozen_model.pb")
        print("finish!")

python freeze_graph.py

    # info:
    Instructions for updating:
    Use tf.compat.v1.graph_util.extract_sub_graph
    finish!

Test the PB model

test_pb.py

    from tensorflow.contrib.keras.api.keras.preprocessing import image
    import tensorflow as tf
    import numpy as np

    # Load the frozen graph into the default graph
    with tf.gfile.GFile('frozen_model.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')

    with tf.Session() as sess:
        img = image.load_img('face.jpeg', target_size=(112, 112))
        img = image.img_to_array(img)
        img = img[..., ::-1]  # reverse the channel axis (RGB -> BGR)
        input_data = np.expand_dims(img, 0)
        init = tf.global_variables_initializer()
        sess.run(init)
        # print(tf.get_default_graph().get_operations())
        # The printed operations show that the input tensor is named "data"
        # and the fully connected layer's output tensor is named "output".
        # op = sess.graph.get_tensor_by_name('fc1/add_1:0')
        # predict = sess.run(op, feed_dict={'data:0': input_data})
        op = sess.graph.get_tensor_by_name('output:0')
        predict = sess.run(op, feed_dict={'data:0': input_data})
        print(predict)

python test_pb.py

    # info:
    [[ 0.3444914 0.10190733 -0.2450121 0.68190414 0.8809633 0.3477656
    -0.94702375 -0.6758606 -0.43900838 0.8143138 -0.48990518 -0.43025577
    -0.5064386 -0.2736677 0.6360168 -0.5352382 0.13765849 0.40842175
    0.7652537 -0.8959745 0.42129484 -0.3829043 0.02023116 -0.14839967
    0.9108265 -0.27738202 -0.6017342 -0.21495399 0.37644026 0.48894358
    -0.88244045 0.31846407 0.19936593 0.2729677 1.5621401 0.4327974
    -0.6486902 -0.23494998 -0.8708729 -0.7705149 0.09255301 -1.1803752
    -0.17184293 0.25432175 -0.19088468 -0.2687335 0.91608775 -0.18984997
    -0.4793176 -0.32987317 -1.3085989 1.295933 -0.00581315 0.12396483
    1.2034997 0.09919491 -1.9225343 0.9287349 -0.28599253 0.11249258
    -1.4562799 1.876776 -1.222248 -0.03905599 0.5152738 0.04876841
    0.04672143 -0.32384768 -0.88341135 0.58193433 -0.73789096 -0.30820417
    0.22141536 0.7255656 0.24394391 0.65632653 -0.4676048 -0.38698527
    -0.11467646 -0.9940818 -1.1298064 1.0152006 0.03592108 0.6738041
    -0.5814836 0.1565624 -0.06737386 -1.0402162 -0.92868716 -0.1109117
    -0.66596234 0.03736827 -0.35337996 -0.41753033 -0.4725821 -0.623832
    -0.86612093 -0.5230911 -1.7838906 0.0866183 -0.02590806 0.23406453
    0.77719873 1.4410781 0.41925526 0.45601833 -0.02141543 0.7005538
    -0.5872698 -0.39757127 1.2808269 -1.1874311 -0.2726869 -0.8244321
    0.39704597 -1.2778006 -0.5276277 -0.26455772 1.2137022 0.04997511
    0.05591565 1.0264047 1.5093955 -0.56345737 -1.1715721 -0.64634734
    0.60211873 0.67259526]]
7. Convert the .pb model to a .tflite model for TensorFlow Lite (mobile deployment)

pb2tflite.py

    import tensorflow as tf

    convert = tf.lite.TFLiteConverter.from_frozen_graph(
        "frozen_model.pb",
        input_arrays=["data"],
        output_arrays=["output"],
        input_shapes={"data": [1, 112, 112, 3]})
    convert.post_training_quantize = True  # whether to quantize
    tflite_model = convert.convert()
    with open("mobilenet.tflite", "wb") as f:
        f.write(tflite_model)
    print("finish!")

python pb2tflite.py

info: mobilenet.tflite is generated in the current folder.
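
The generated file can be checked the same way as the earlier models. The sketch below is one possible way to do it with tf.lite.Interpreter; make_input and run_tflite are hypothetical helper names, and the dummy all-zero image only exercises the shape handling. Because post_training_quantize is enabled, the output will probably differ from the float results above by more than the tiny differences seen in steps 5 and 6.

```python
import numpy as np


def make_input(rgb_image):
    """Turn an (H, W, 3) RGB array into a (1, H, W, 3) float32 BGR batch."""
    bgr = np.asarray(rgb_image, dtype=np.float32)[..., ::-1]
    # Copy into contiguous memory so the batch does not carry negative strides.
    return np.ascontiguousarray(np.expand_dims(bgr, 0))


def run_tflite(model_path, input_data):
    """Run one forward pass through a .tflite file and return the first output."""
    # Imported here so make_input can be used without TensorFlow installed.
    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]
    interpreter.set_tensor(input_index, input_data)
    interpreter.invoke()
    return interpreter.get_tensor(output_index)


# Shape check on a dummy image; a real run would load face.jpeg as in test_pb.py.
dummy = np.zeros((112, 112, 3), dtype=np.uint8)
batch = make_input(dummy)
print(batch.shape, batch.dtype)
```

With the real image loaded as in test_pb.py, run_tflite("mobilenet.tflite", batch) would return the embedding to compare against the earlier outputs.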




Unless otherwise stated, all articles on this blog are licensed under CC BY-SA 3.0. Please credit the source when reposting.