欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

tensorflow得到图的所有节点名称以及得到节点输出

程序员文章站 2022-05-30 16:39:15
...
 # 默认图的所有节点名称
 # tensor_name_list = [tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
 #图graph的所有节点名称
 # tensor_name_list = [tensor.name for tensor in graph.as_graph_def().node]
 # print(tensor_name_list)
   #数据处理
   string = '设置一个随机种子'
   char_list = ['[CLS]'] + list(string) +['[SEP]']
   #不做masked处理
   mask_list = [1] * (len(string)+2)
   #不做分词处理
   seg_list = [0] * (len(string)+2)
   
   # 根据bert的词表做一个char_to_id的操作
   # 未登录词会报错,更改报错代码使未登录词时为'[UNK]'
   # 也可以自己实现
   token = tokenization.FullTokenizer(vocab_file='chinese_L-12_H-768_A-12/vocab.txt')
   char_list = token.convert_tokens_to_ids(char_list)

   char_lists = [char_list]
   mask_lists = [mask_list]
   (seg_lists = [seg_list]
   
   input_ids = sess.graph.get_tensor_by_name('input_ids:0')
   input_mask = sess.graph.get_tensor_by_name('input_masks:0')
   segment_ids = sess.graph.get_tensor_by_name('segment_ids:0')
   
    # bert12层transformer,取最后一层的输出
    output = sess.graph.get_tensor_by_name('bert/encoder/layer_11/output/LayerNorm/batchnorm/add_1:0')
   
   feed_data = {input_ids: np.asarray(char_lists), input_mask: np.asarray(mask_lists), segment_ids: np.asarray(seg_lists)}
   
   embedding = sess.run(output, feed_dict=feed_data)
   #bert输出向量结果分批次没有节点,这里reshape成bert_model.get_sequence_output()的形状
   embedding = np.reshape(embedding, (len(char_lists), len(char_lists[0]), -1))
相关标签: 算法