欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

搞清楚TF中的Embedding

程序员文章站 2022-06-15 15:36:01
...

参考的资料:
详解TF中的Embedding操作
官网Embedding教程

1.为什么引入Embedding

在对词汇表中单词进行编码时
若使用one-hot编码,存在以下问题:

  • 编码效率低下,假设我们的词汇表中有 10,000 个单词。为了对每个单词进行独热编码,我们将创建一个其中 99.99% 的元素都为零的向量。

若使用唯一的数字编码每个单词,则存在以下问题:

  • 整数编码是任意的,不会捕捉单词之间的任何关系,比如近义词或相同类型的词可能存在某种联系。
  • 对于要解释的模型而言,整数编码颇具挑战。例如,线性分类器针对每个特征学习一个权重。由于任何两个单词的相似性与其编码的相似性之间都没有关系,因此这种特征权重组合没有意义。

因此我们使用Embedding,单词嵌入向量为我们提供了一种使用高效、密集表示的方法,其中相似的单词具有相似的编码。重要的是,我们不必手动指定此编码。嵌入向量是浮点值的密集向量(向量的长度是您指定的参数)。它们是可以训练的参数(模型在训练过程中学习的权重,与模型学习密集层权重的方法相同),无需手动为嵌入向量指定值。8 维的单词嵌入向量(对于小型数据集)比较常见,而在处理大型数据集时最多可达 1024 维。维度更高的嵌入向量可以捕获单词之间的细粒度关系,但需要更多的数据来学习。
搞清楚TF中的Embedding

2. Embedding的工作原理

Embedding的实质就是全连接层
搞清楚TF中的Embedding

在TF2中我们建立Embedding的表达式为:

tf.keras.layers.Embedding(vocab_size, embedding_dim)

其中vocab-size是需要编码的词的数量,可理解为上图中左边节点数量。embedding_dim则为右边节点数量。
假设一个特征有5个取值,即one-hot后变成5维,我们将其转换为embedding,其实就是将其one-hot后接入一个dence层。

例:
假设embedding层为Embedding(5, 10)
(1,) --> (1,5) --> (1,5) * (5, 10) --> (1,10)

2.1 由TF1进行Embedding过程解析

在tf1.x中,我们使用embedding_lookup函数来实现emedding,代码如下:

# embedding
embedding = tf.constant(
        [[0.21,0.41,0.51,0.11]],
        [0.22,0.42,0.52,0.12],
        [0.23,0.43,0.53,0.13],
        [0.24,0.44,0.54,0.14]],dtype=tf.float32)

feature_batch = tf.constant([2,3,1,0])

get_embedding1 = tf.nn.embedding_lookup(embedding,feature_batch)

上面的过程为:
搞清楚TF中的Embedding
注意这里的维度的变化,假设我们的feature_batch 是 1维的tensor,长度为4,而embedding的长度为4,那么得到的结果是 4 * 4 的,同理,假设feature_batch是2 *4的,embedding_lookup后的结果是2 * 4 * 4。后面我们在观察结果。

上文说过,embedding层其实是一个全连接神经网络层,那么其过程等价于:

搞清楚TF中的Embedding
可以得到下面的代码:

embedding = tf.constant(
    [
        [0.21,0.41,0.51,0.11],
        [0.22,0.42,0.52,0.12],
        [0.23,0.43,0.53,0.13],
        [0.24,0.44,0.54,0.14]
    ],dtype=tf.float32)

feature_batch = tf.constant([2,3,1,0])
feature_batch_one_hot = tf.one_hot(feature_batch,depth=4)
get_embedding2 = tf.matmul(feature_batch_one_hot,embedding)

二者是否一致呢?我们通过代码来验证一下:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    embedding1,embedding2 = sess.run([get_embedding1,get_embedding2])
    print(embedding1)
    print(embedding2)

得到的结果为:
搞清楚TF中的Embedding

后续补充…

总结

进行Embedding研究因为在模型训练中遇到以下问题:
加载了原有的检查点后使用不同词汇量进行训练,发现会出现embedding层不一致的问题,先研究到这,后续再想办法解决怎么才能使用不同词汇量数据不产生embedding冲突

可能解决上述问题的思路有:

  • embedding的vocab-size大小的影响,是否可以设置很大的值,这样应该不会产生此问题。但此举又有什么副作用?