mxnet模型转tensorflow模型

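requirements:pip install mxnet "tensorflow<2" six pillow(下文脚本使用 tf.Session、tf.contrib 等 TensorFlow 1.x 接口;six 用于 test_tensorflow.py,pillow 供 load_img 读取图片)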

1.安装mmdnn

pip install -U git+https://github.com/Microsoft/MMdnn.git@master

2.模型下载

https://pan.baidu.com/s/1If28BkHde4fiuweJrbicVA

3.mxnet模型转IR
python -m mmdnn.conversion._script.convertToIR -f mxnet -n model-symbol.json -w model-0000.params -d tf_model --inputShape 3,112,112 
#info:
IR network structure is saved as [tf_model.json].#可视化文件
IR network structure is saved as [tf_model.pb].#网络结构
IR weights are saved as [tf_model.npy].#权重参数

若报错:AttributeError: 'NoneType' object has no attribute 'asnumpy'(出错位置为 mxnet_parser.py 第 410 行)

解决:找到 python3.6/site-packages/mmdnn/conversion/mxnet/mxnet_parser.py 第 410 行,修改如下(注意:该修改把权重名硬编码为 fc1_weight,只适用于全连接层名为 fc1 的模型,换用其他模型时需按实际层名调整):

weight = self.weight_data.get("fc1_weight").asnumpy().transpose((1, 0))

参考:https://github.com/microsoft/MMdnn/issues/231

4.生成tf_model.py 用于还原神经网络结构

调用tf_model.py中的KitModel函数加载npy权重参数可重新生成原网络框架

python -m mmdnn.conversion._script.IRToCode -f tensorflow --IRModelPath tf_model.pb --IRWeightPath tf_model.npy --dstModelPath tf_model.py
#info:
Parse file [tf_model.pb] with binary format successfully.
Target network code snippet is saved as [tf_model.py].
5.验证模型输出结果是否一致

5.1 test_mxnet.py

import mxnet as mx
from tensorflow.contrib.keras.api.keras.preprocessing import image
import numpy as np

from collections import namedtuple

# Lightweight batch container; Module.forward below is only given a `.data` field.
Batch = namedtuple('Batch', ['data'])

ctx = mx.cpu(0)

# Load the checkpoint: the 'mobile' folder holds model-symbol.json and
# model-0000.params (prefix 'mobile/model', epoch 0).
sym, arg_params, aux_params = mx.model.load_checkpoint('mobile/model', 0)
mod = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
# Inference only and no label names, so bind without label shapes. Passing None
# explicitly replaces the previous read of the private `mod._label_shapes`,
# which MXNet's Module leaves as None until the first bind.
mod.bind(for_training=False, data_shapes=[('data', (1, 3, 112, 112))],
         label_shapes=None)
# NOTE(review): allow_missing=True silently tolerates parameters absent from the
# checkpoint; switch to False to fail fast if weights and symbol do not match.
mod.set_params(arg_params, aux_params, allow_missing=True)

path = 'face.jpeg'
# Resize to 112x112 and reverse the channel order (RGB -> BGR); the TensorFlow
# tests apply the same preprocessing so the outputs can be compared directly.
img = image.load_img(path, target_size=(112, 112))
img = image.img_to_array(img)
img = img[..., ::-1]

# Add a batch axis and move channels first (NHWC -> NCHW) to match the bound
# data shape (1, 3, 112, 112).
img = np.expand_dims(img, 0).transpose((0, 3, 1, 2))

mod.forward(Batch([mx.nd.array(img)]))
prob = mod.get_outputs()[0].asnumpy()
prob = np.squeeze(prob)
print(prob)

执行命令

python test_mxnet.py 
#info:
[17:39:57] src/nnvm/legacy_json_util.cc:209: Loading symbol saved by previous version v1.0.0. Attempting to upgrade...
[17:39:57] src/nnvm/legacy_json_util.cc:217: Symbol successfully upgraded!
[ 0.344491    0.10190611 -0.24501216  0.6819046   0.88096315  0.347766
 -0.94702303 -0.67586213 -0.43900824  0.81431276 -0.4899036  -0.43025514
 -0.50644076 -0.27366892  0.63601595 -0.5352368   0.13765731  0.40842316
  0.76525426 -0.8959755   0.42129532 -0.38290668  0.02023177 -0.14840017
  0.9108279  -0.27738237 -0.6017331  -0.214954    0.37644073  0.48894417
 -0.8824417   0.31846505  0.19936565  0.27296835  1.5621403   0.4327985
 -0.6486908  -0.23494942 -0.8708738  -0.77051663  0.09255238 -1.1803752
 -0.17184262  0.2543226  -0.19088541 -0.26873437  0.9160875  -0.18985008
 -0.4793183  -0.32987356 -1.3085973   1.2959319  -0.00581244  0.12396478
  1.2034996   0.0991946  -1.9225345   0.92873436 -0.285992    0.11249313
 -1.4562801   1.8767762  -1.2222489  -0.03905598  0.5152731   0.04876914
  0.04671988 -0.32384786 -0.88341314  0.58193505 -0.7378911  -0.3082042
  0.22141728  0.7255646   0.24394599  0.6563271  -0.46760473 -0.38698462
 -0.11467619 -0.9940818  -1.1298056   1.015201    0.03592067  0.6738027
 -0.5814839   0.1565634  -0.06737513 -1.040216   -0.9286871  -0.11091176
 -0.66596293  0.03736701 -0.35337982 -0.4175317  -0.47258058 -0.62383175
 -0.86612004 -0.5230916  -1.7838901   0.08661752 -0.02590845  0.23406455
  0.77719927  1.4410776   0.41925532  0.4560187  -0.02141571  0.7005563
 -0.58727044 -0.39757103  1.2808248  -1.1874324  -0.27268586 -0.82443166
  0.39704558 -1.2778002  -0.52762616 -0.26455742  1.2137026   0.04997367
  0.05591454  1.0264031   1.5093948  -0.5634581  -1.1715719  -0.646347
  0.6021179   0.6725963 ]

5.2 test_tensorflow.py

from __future__ import absolute_import
import argparse
import numpy as np
from six import text_type as _text_type
from tensorflow.contrib.keras.api.keras.preprocessing import image
import tensorflow as tf


parser = argparse.ArgumentParser()
parser.add_argument('-n', type=_text_type, default='kitModel',
                    help='Network structure file name.')
parser.add_argument('-w', type=_text_type, required=True,
                    help='Network weights file name')
parser.add_argument('--image', '-i',
                    type=_text_type, help='Test image path.',
                    default="face.jpeg")

args = parser.parse_args()
if args.n.endswith('.py'):
    args.n = args.n[:-3]

model_converted = __import__(args.n).KitModel(args.w)
input_tf, model_tf = model_converted

img = image.load_img(args.image, target_size = (112, 112))
img = image.img_to_array(img)
img = img[..., ::-1]

input_data = np.expand_dims(img, 0)

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    predict = sess.run(model_tf, feed_dict = {input_tf : input_data})
print(predict)

执行命令

python test_tensorflow.py -n tf_model.py -w tf_model.npy -i face.jpeg
#info:
2019-10-25 17:51:37.745502: I tensorflow/core/common_runtime/process_util.cc:71] Creating new thread pool with default inter op setting: 12. Tune using inter_op_parallelism_threads for best performance.
[[ 0.3444912   0.10190725 -0.24501228  0.6819044   0.88096356  0.3477651
  -0.9470245  -0.67586106 -0.43900767  0.8143126  -0.48990446 -0.43025535
  -0.50643945 -0.27366814  0.63601726 -0.5352377   0.13765681  0.40842274
   0.7652553  -0.8959763   0.42129317 -0.38290572  0.02023016 -0.14840023
   0.91082776 -0.27738187 -0.60173315 -0.2149537   0.37644142  0.48894492
  -0.8824413   0.3184655   0.19936629  0.2729676   1.5621389   0.4327973
  -0.6486915  -0.23494866 -0.87087345 -0.77051604  0.09255352 -1.180374
  -0.17184272  0.25432315 -0.19088425 -0.26873374  0.91608876 -0.18985137
  -0.4793172  -0.3298719  -1.308598    1.2959337  -0.00581198  0.12396422
   1.2034999   0.09919477 -1.9225347   0.92873377 -0.28599226  0.11249284
  -1.4562793   1.876776   -1.2222495  -0.03905648  0.5152732   0.04876836
   0.04672025 -0.32384863 -0.8834132   0.581934   -0.7378913  -0.30820462
   0.22141635  0.72556514  0.2439455   0.6563256  -0.46760577 -0.38698506
  -0.1146768  -0.9940842  -1.1298054   1.015199    0.03592021  0.67380327
  -0.58148336  0.15656358 -0.06737413 -1.0402167  -0.9286856  -0.11091161
  -0.66596127  0.03736706 -0.35337985 -0.41753066 -0.47258082 -0.62383235
  -0.8661205  -0.52309173 -1.7838898   0.08661895 -0.02590791  0.23406385
   0.7771991   1.4410769   0.41925538  0.45601875 -0.02141583  0.70055544
  -0.587271   -0.3975702   1.2808259  -1.1874334  -0.27268624 -0.8244319
   0.39704552 -1.2778007  -0.5276267  -0.2645575   1.2137012   0.04997464
   0.05591418  1.0264043   1.5093954  -0.5634565  -1.1715722  -0.6463482
   0.60211945  0.67259526]]
6.基于tf_model.npy和tf_model.py文件,固化参数,生成PB文件

freeze_graph.py

import tensorflow as tf
import tf_model as tf_fun
def netWork(weight_file="./tf_model.npy"):
    """Rebuild the converted network and load its weights.

    Calls ``KitModel`` from the MMdnn-generated ``tf_model.py`` with the
    ``.npy`` weight file produced by the IR conversion step.

    Args:
        weight_file: path to the ``.npy`` weights. Defaults to
            ``"./tf_model.npy"``, the file written by the conversion commands.

    Returns:
        Whatever ``KitModel`` returns; callers in this script unpack it as an
        ``(input_tensor, output_tensor)`` pair.
    """
    return tf_fun.KitModel(weight_file)
def freeze_graph(output_graph, output_node_name="output"):
    """Freeze the converted network into a single self-contained GraphDef file.

    The network is rebuilt via ``netWork()``, and its output tensor is given a
    stable name with ``tf.identity``. Every variable feeding that output is
    folded into a constant, and the resulting GraphDef is serialized to
    ``output_graph``.

    Args:
        output_graph: destination path of the frozen ``.pb`` file.
        output_node_name: name given to the output node. This is the name to
            fetch (``"<name>:0"``) or pass as ``output_arrays`` when converting
            to TFLite. Defaults to ``"output"``.
    """
    # Build in a fresh graph so that nodes already present in the default graph
    # (e.g. from an earlier call) cannot collide with the output name. With a
    # clash, tf.identity would create "output_1" while the stale "output" node
    # was frozen instead.
    with tf.Graph().as_default():
        _, fc1 = netWork()  # (input tensor, network output)
        tf.identity(fc1, name=output_node_name)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # Replace the variables feeding the output with constants holding
            # their current values, keeping only the subgraph the output needs.
            output_graph_def = tf.graph_util.convert_variables_to_constants(
                sess=sess,
                input_graph_def=sess.graph_def,
                output_node_names=[output_node_name])

        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())

if __name__ == '__main__':
    pb_path = "frozen_model.pb"
    freeze_graph(pb_path)  # write the frozen GraphDef to pb_path
    print("finish!")

python freeze_graph.py

#info:
Instructions for updating:
Use tf.compat.v1.graph_util.extract_sub_graph
finish!

测试pb模型

test_pb.py


from tensorflow.contrib.keras.api.keras.preprocessing import image
import tensorflow as tf
import numpy as np

# Load the frozen GraphDef and import it into the default graph without a name
# prefix (name=''), so tensors keep their original names. GFile replaces the
# deprecated FastGFile.
with tf.gfile.GFile('frozen_model.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

# Resize to 112x112, reverse the channel order (RGB -> BGR) and add a batch axis.
img = image.load_img('face.jpeg', target_size=(112, 112))
img = image.img_to_array(img)
img = img[..., ::-1]
input_data = np.expand_dims(img, 0)

with tf.Session() as sess:
    # The frozen graph holds only constants, so no variable initializer is needed.
    # Tensor names were found by printing tf.get_default_graph().get_operations():
    # the input is 'data', and 'output' is the identity node on top of the
    # fully-connected layer (the raw layer output is also reachable as
    # 'fc1/add_1:0').
    op = sess.graph.get_tensor_by_name('output:0')
    predict = sess.run(op, feed_dict={'data:0': input_data})
print(predict)

python test_pb.py

#info:
[[ 0.3444914   0.10190733 -0.2450121   0.68190414  0.8809633   0.3477656
  -0.94702375 -0.6758606  -0.43900838  0.8143138  -0.48990518 -0.43025577
  -0.5064386  -0.2736677   0.6360168  -0.5352382   0.13765849  0.40842175
   0.7652537  -0.8959745   0.42129484 -0.3829043   0.02023116 -0.14839967
   0.9108265  -0.27738202 -0.6017342  -0.21495399  0.37644026  0.48894358
  -0.88244045  0.31846407  0.19936593  0.2729677   1.5621401   0.4327974
  -0.6486902  -0.23494998 -0.8708729  -0.7705149   0.09255301 -1.1803752
  -0.17184293  0.25432175 -0.19088468 -0.2687335   0.91608775 -0.18984997
  -0.4793176  -0.32987317 -1.3085989   1.295933   -0.00581315  0.12396483
   1.2034997   0.09919491 -1.9225343   0.9287349  -0.28599253  0.11249258
  -1.4562799   1.876776   -1.222248   -0.03905599  0.5152738   0.04876841
   0.04672143 -0.32384768 -0.88341135  0.58193433 -0.73789096 -0.30820417
   0.22141536  0.7255656   0.24394391  0.65632653 -0.4676048  -0.38698527
  -0.11467646 -0.9940818  -1.1298064   1.0152006   0.03592108  0.6738041
  -0.5814836   0.1565624  -0.06737386 -1.0402162  -0.92868716 -0.1109117
  -0.66596234  0.03736827 -0.35337996 -0.41753033 -0.4725821  -0.623832
  -0.86612093 -0.5230911  -1.7838906   0.0866183  -0.02590806  0.23406453
   0.77719873  1.4410781   0.41925526  0.45601833 -0.02141543  0.7005538
  -0.5872698  -0.39757127  1.2808269  -1.1874311  -0.2726869  -0.8244321
   0.39704597 -1.2778006  -0.5276277  -0.26455772  1.2137022   0.04997511
   0.05591565  1.0264047   1.5093955  -0.56345737 -1.1715721  -0.64634734
   0.60211873  0.67259526]]
7. .pb模型转换为.tflite模型供tensorflow lite调用(移动端部署)

pb2tflite.py

import tensorflow as tf

# Convert the frozen graph. 'data' is the input tensor and 'output' the identity
# node added when freezing. The input is one 112x112 image with 3 channels (NHWC).
convert = tf.lite.TFLiteConverter.from_frozen_graph("frozen_model.pb", input_arrays=["data"],
                                                  output_arrays=["output"],
                                                  input_shapes={"data": [1, 112, 112, 3]})
# Enable post-training quantization.
# NOTE(review): newer TF releases deprecate this flag in favour of
# `convert.optimizations = [tf.lite.Optimize.DEFAULT]`; switch if your TF warns.
convert.post_training_quantize = True
tflite_model = convert.convert()
# Use a context manager so the file is flushed and closed deterministically.
with open("mobilenet.tflite", "wb") as f:
    f.write(tflite_model)
print("finish!")

python pb2tflite.py

#info: 当前目录下生成 mobilenet.tflite



深度学习      模型转换

本博客所有文章除特别声明外,均采用 CC BY-SA 3.0协议 。转载请注明出处!