欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

tf.nn.embedding_lookup()函数

程序员文章站 2024-02-29 21:17:22
...

一、tf.nn.embedding_lookup()

tf.nn.embedding_lookup函数的用法主要是选取一个张量里面索引对应的元素。tf.nn.embedding_lookup(tensor, id):tensor就是输入张量,id就是张量对应的索引。

二、代码

import tensorflow as tf
import numpy as np

input_ids = tf.placeholder(dtype=tf.int32, shape=[None])

embedding = tf.Variable(np.identity(5, dtype=np.int32))
input_embedding = tf.nn.embedding_lookup(embedding, input_ids)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
print(embedding.eval())
print(sess.run(input_embedding, feed_dict={input_ids:[1, 2, 3, 0, 3, 2, 1]}))

tf.nn.embedding_lookup()函数

(1)

print(embedding.eval())

tf.nn.embedding_lookup()函数

(2)

print(sess.run(input_embedding, feed_dict={input_ids:[1, 2, 3, 0, 3, 2, 1]}))

tf.nn.embedding_lookup()函数

根据input_ids中的id,寻找embedding中的对应元素。比如,input_ids=[1,3,5],则找出embedding中下标为1,3,5的向量组成一个矩阵返回。

三、Embedding原理

参考:https://blog.csdn.net/laolu1573/article/details/77170407

tf.nn.embedding_lookup()函数