tensorflow得到图的所有节点名称以及得到节点输出
程序员文章站
2022-05-30 16:39:15
...
# 默认图的所有节点名称
# tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
#图graph的所有节点名称
# tensor_name_list = [tensor.name for tensor in graph.as_graph_def().node]
# print(tensor_name_list)
#数据处理
string = '设置一个随机种子'
char_list = ['[CLS]'] + list(string) +['[SEP]']
#不做masked处理
mask_list = [1] * (len(string)+2)
#不做分词处理
seg_list = [0] * (len(string)+2)
# 根据bert的词表做一个char_to_id的操作
# 未登录词会报错,更改报错代码使未登录词时为'[UNK]'
# 也可以自己实现
token = tokenization.FullTokenizer(vocab_file='chinese_L-12_H-768_A-12/vocab.txt')
char_list = token.convert_tokens_to_ids(char_list)
char_lists = [char_list]
mask_lists = [mask_list]
(seg_lists = [seg_list]
input_ids = sess.graph.get_tensor_by_name('input_ids:0')
input_mask = sess.graph.get_tensor_by_name('input_masks:0')
segment_ids = sess.graph.get_tensor_by_name('segment_ids:0')
# bert12层transformer,取最后一层的输出
output = sess.graph.get_tensor_by_name('bert/encoder/layer_11/output/LayerNorm/batchnorm/add_1:0')
feed_data = {input_ids: np.asarray(char_lists), input_mask: np.asarray(mask_lists), segment_ids: np.asarray(seg_lists)}
embedding = sess.run(output, feed_dict=feed_data)
#bert输出向量结果分批次没有节点,这里reshape成bert_model.get_sequence_output()的形状
embedding = np.reshape(embedding, (len(char_lists), len(char_lists[0]), -1))
下一篇: webpack打包前端项目入门