如题。
用TensorFlow 2(Keras)训练的模型,如何转换成nb文件,并让VIM3的NPU进行加速?
目前我们这边有新的sdk,支持转换keras模型,最近会放出来。
用你现在手里的SDK也是可以的,但是你需要先将Keras模型转换为TensorFlow pb模型,这个是有通用方法的,你搜一下应该可以找到,然后再进行转换。
@ThinkBird 你可以试一下我这个代码
from keras.models import load_model
import tensorflow as tf
import argparse
import sys
import os
import os.path as osp
from keras import backend as K
def h5_to_pb(h5_model, output_dir, model_name, out_prefix="output_"):
    """Freeze a loaded Keras model into a binary TensorFlow ``.pb`` graph.

    Every model output gets an ``Identity`` node named ``<out_prefix><k>``
    (k counting from 1). The session's variables are then folded into
    constants, and the resulting GraphDef is written in binary form.

    Args:
        h5_model: Keras model, as returned by ``load_model``.
        output_dir: Directory the ``.pb`` is written to. It is created,
            including missing parents, if it does not exist. An empty string
            means the current directory.
        model_name: File name of the ``.pb`` inside ``output_dir``.
        out_prefix: Prefix of the generated output node names.

    Returns:
        The list of generated output node names (e.g. ``["output_1"]``),
        usable as the converter's ``--outputs`` argument.
    """
    from tensorflow.python.framework import graph_util, graph_io

    output_dir = output_dir or "."
    if not osp.exists(output_dir):
        # makedirs (unlike mkdir) also creates missing parent directories.
        os.makedirs(output_dir)

    out_nodes = []
    # Iterate ``outputs``, which is always a list of tensors. For a
    # single-output model ``h5_model.output`` is a bare tensor, so the old
    # ``h5_model.output[i]`` sliced it (a StridedSlice on the first axis)
    # instead of exporting the real output tensor.
    for index, output_tensor in enumerate(h5_model.outputs, start=1):
        node_name = out_prefix + str(index)
        tf.identity(output_tensor, name=node_name)
        out_nodes.append(node_name)

    sess = K.get_session()
    init_graph = sess.graph.as_graph_def()
    main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
    graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
    return out_nodes
def resolve_output_path(h5_model_path, model_output_path=None):
    """Return ``(output_dir, pb_file_name)`` for the frozen graph.

    If ``model_output_path`` is given, it is split into a directory and a
    file name. A bare file name maps to ``"./"``, and a path ending in a
    separator is treated as a directory. Otherwise the ``.pb`` is written to
    ``"./"`` and named after the ``.h5`` file, with its extension (whatever
    its length) replaced by ``.pb``.
    """
    default_name = osp.splitext(osp.basename(h5_model_path))[0] + ".pb"
    if not model_output_path:
        return "./", default_name
    out_dir, out_name = osp.split(model_output_path)
    return out_dir or "./", out_name or default_name


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", help="the h5 model path ")
    parser.add_argument("--model_output_path", help="the pb model output path ")
    args = parser.parse_args()
    if not args.model_path:
        sys.exit("model_path not found ! Please use this format : --model_path")
    h5_model_path = args.model_path
    # os.path-based splitting: the old rfind('/') slicing turned a bare
    # "model.pb" into the directory "model.p".
    h5_model_dir, h5_model_name = resolve_output_path(h5_model_path, args.model_output_path)
    h5_model = load_model(h5_model_path)
    h5_to_pb(h5_model, h5_model_dir, h5_model_name)
但是TF2做的模型不保证能转换,有一些新的接口是不支持的
目前SDK中的软件包比较陈旧,TF2模型需要再用TF1重新实现一遍,极为麻烦。
急切需要较新的软件支持,比如 tensorflow 2.1,希望能尽快释出新版本SDK
按照此方法可以从h5转换到pb文件,但是在执行 0_import_model.sh时,有错误提示:
Traceback (most recent call last):
File "convertensorflow.py", line 62, in <module>
File "convertensorflow.py", line 58, in main
File "acuitylib/app/importer/import_tensorflow.py", line 81, in run
File "acuitylib/converter/convert_tf.py", line 97, in __init__
File "acuitylib/converter/tensorflowloader.py", line 53, in __init__
AttributeError: 'NoneType' object has no attribute 'op'
[3472] Failed to execute script convertensorflow
附:
0_import_model.sh 文件内容如下:
#!/bin/bash
# Run the Acuity TensorFlow converter on ./model/mnist_model.pb, writing
# ${NAME}.json (--net-output) and ${NAME}.data (--data-output).

NAME=mnist
ACUITY_PATH=../bin/

# Converter executables under ${ACUITY_PATH}; only convert_tf is invoked below.
convert_caffe=${ACUITY_PATH}convertcaffe
convert_tf=${ACUITY_PATH}convertensorflow
convert_tflite=${ACUITY_PATH}convertflite
convert_darknet=${ACUITY_PATH}convertdarknet
convert_onnx=${ACUITY_PATH}convertonnx

# Graph node names passed as inputs/outputs ("output_1" is the name h5_to_pb
# gives the first model output with its default "output_" prefix).
tf_import_args=(
    --tf-pb ./model/mnist_model.pb
    --inputs input
    --input-size-list '784'
    --outputs output_1
    --net-output "${NAME}.json"
    --data-output "${NAME}.data"
)
"$convert_tf" "${tf_import_args[@]}"
pb文件部分内容如下:
node {
name: "strided_slice"
op: "StridedSlice"
input: "output/Softmax"
input: "strided_slice/stack"
input: "strided_slice/stack_1"
input: "strided_slice/stack_2"
attr {
key: "Index"
value {
type: DT_INT32
}
}
attr {
key: "T"
value {
type: DT_FLOAT
}
}
attr {
key: "begin_mask"
value {
i: 0
}
}
attr {
key: "ellipsis_mask"
value {
i: 0
}
}
attr {
key: "end_mask"
value {
i: 0
}
}
attr {
key: "new_axis_mask"
value {
i: 0
}
}
attr {
key: "shrink_axis_mask"
value {
i: 1
}
}
}
node {
name: "output_1"
op: "Identity"
input: "strided_slice"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
另外,h5转pb使用的是tensorflow 2 的v1 兼容模式,代码如下:
from tensorflow.compat.v1.keras.models import load_model
import tensorflow.compat.v1 as tf
import argparse
import sys
import os
import os.path as osp
from tensorflow.compat.v1.keras import backend as K
# Build Keras models in graph mode so h5_to_pb can freeze them through a
# TF1-style session (K.get_session()). `tf` is already tensorflow.compat.v1.
tf.disable_eager_execution()
def h5_to_pb(h5_model, output_dir, model_name, out_prefix="output_"):
    """Freeze a loaded Keras model into a binary TensorFlow ``.pb`` graph.

    Every model output gets an ``Identity`` node named ``<out_prefix><k>``
    (k counting from 1). The session's variables are then folded into
    constants, and the resulting GraphDef is written in binary form.

    Args:
        h5_model: Keras model, as returned by ``load_model``.
        output_dir: Directory the ``.pb`` is written to. It is created,
            including missing parents, if it does not exist. An empty string
            means the current directory.
        model_name: File name of the ``.pb`` inside ``output_dir``.
        out_prefix: Prefix of the generated output node names.

    Returns:
        The list of generated output node names (e.g. ``["output_1"]``),
        usable as the converter's ``--outputs`` argument.
    """
    from tensorflow.python.framework import graph_util, graph_io

    output_dir = output_dir or "."
    if not osp.exists(output_dir):
        # makedirs (unlike mkdir) also creates missing parent directories.
        os.makedirs(output_dir)

    out_nodes = []
    # Iterate ``outputs``, which is always a list of tensors. For a
    # single-output model ``h5_model.output`` is a bare tensor, so the old
    # ``h5_model.output[i]`` sliced it (a StridedSlice on the first axis)
    # instead of exporting the real output tensor.
    for index, output_tensor in enumerate(h5_model.outputs, start=1):
        node_name = out_prefix + str(index)
        tf.identity(output_tensor, name=node_name)
        out_nodes.append(node_name)

    sess = K.get_session()
    init_graph = sess.graph.as_graph_def()
    main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
    graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
    return out_nodes
def resolve_output_path(h5_model_path, model_output_path=None):
    """Return ``(output_dir, pb_file_name)`` for the frozen graph.

    If ``model_output_path`` is given, it is split into a directory and a
    file name. A bare file name maps to ``"./"``, and a path ending in a
    separator is treated as a directory. Otherwise the ``.pb`` is written to
    ``"./"`` and named after the ``.h5`` file, with its extension (whatever
    its length) replaced by ``.pb``.
    """
    default_name = osp.splitext(osp.basename(h5_model_path))[0] + ".pb"
    if not model_output_path:
        return "./", default_name
    out_dir, out_name = osp.split(model_output_path)
    return out_dir or "./", out_name or default_name


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", help="the h5 model path ")
    parser.add_argument("--model_output_path", help="the pb model output path ")
    args = parser.parse_args()
    if not args.model_path:
        sys.exit("model_path not found ! Please use this format : --model_path")
    h5_model_path = args.model_path
    # os.path-based splitting: the old rfind('/') slicing turned a bare
    # "model.pb" into the directory "model.p".
    h5_model_dir, h5_model_name = resolve_output_path(h5_model_path, args.model_output_path)
    h5_model = load_model(h5_model_path)
    h5_to_pb(h5_model, h5_model_dir, h5_model_name)
新的SDK也是基于TensorFlow 1.13.2的,暂时没有2.x版本。