欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

tensorflow常用模型开发小函数总结(持续补充中)

程序员文章站 2022-05-11 11:10:18
...


使用版本:tensorflow2.2

tf.gather

tf.gather(
params, indices, validate_indices=None, axis=None, batch_dims=0, name=None
)
官方API
用法:
在查询高维稀疏特征的对应嵌入向量值时,非常好用。高维稀疏特征就是类似于商品id这种有非常多离散取值的特征字段,通常建模时会先对它进行编码,转化为0,1,2,3…这样子,然后就可以使用tf.gather获取嵌入向量了。

#定义特征1的嵌入矩阵
emb_mat1=tf.constant(tf.random.normal([5,3]),name='embeddig_matrix1')

#定义特征2的嵌入矩阵
emb_mat2=tf.constant(tf.random.normal([8,3]),name='embeddig_matrix2')

#当样本为(2,4),即特征1取值为2,特征2取值为4时,计算嵌入向量表示
sample=(2,4)
emb1=tf.gather(emb_mat1,[sample[0]],axis=0)
emb2=tf.gather(emb_mat2,[sample[1]],axis=0)

#将嵌入向量进行连接,作为DNN或RNN的输入
vec=tf.concat([emb1,emb2],axis=1)

tf.concat

官方API
用法:
基本上做任何模型都需要做向量拼接,绝对的高频使用函数。
例子可以参考上面那个。

tf.expand_dims

官方API
用法:
给tensor做维度扩展,也是非常常用的小函数。

a=tf.constant([[1,2,3],[4,5,6]],dtype=tf.float32)

tf.expand_dims(a,axis=1)

#扩展结果,shape由2*3变为2*1*3了
<tf.Tensor: shape=(2, 1, 3), dtype=float32, numpy=
array([[[1., 2., 3.]],
       [[4., 5., 6.]]], dtype=float32)>

tf.sequence_mask

官方API

用法:
在处理用户行为序列数据时常与tf.tile组合使用,假设有两个用户小a,小b,他们的历史购买行为分别为[1,2]和[3,4,5],这里的数字代表商品id,通常我们需要将两个行为序列进行对齐,转为[ [1,2,0], [3,4,5] ]这样子的2*3的矩阵。我们定义一个mask=[2,3], 它表示用户小a有两个有效行为,而小b有三个。接下来就该tf.sequence_mask登场了。

#用户历史行为矩阵
hist_behavior=tf.constant([[1,2,0],[3,4,5]],dtype=tf.int32)

#定义mask矩阵
mask=[2,3]
mask=tf.sequence_mask(mask,dtype=tf.float32)

#转化后的mask
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[1., 1., 0.],
       [1., 1., 1.]], dtype=float32)>

tf.tile

tensor复制小函数
官方API

用法:
我们将上一个例子中的用户行为序列转为稠密向量表示,如下:

#定义嵌入向量矩阵
emb_mat=tf.constant(tf.random.normal([10,8]),name='embeddig_matrix')

#行为序列向量表示
hist_emb=tf.gather(emb_mat,hist_behavior,axis=0)

这时我们会发现一个问题,hist_emb中小a的第三个行为向量不为0,我们希望将这个行为向量屏蔽,方法如下:

#mask升维
mask=tf.expand_dims(mask,axis=2) #[B,H,T=1] B表示样本数量,H表示行为个数,T表示特征维度

#将mask按第二个维度进行复制
mask=tf.tile(mask,[1,1,8]) #[B,H,T=8],这样就与hist_emb维度相同了。

#屏蔽小a的第三个行为向量
hist_emb*=mask

#最终结果如下:
<tf.Tensor: shape=(2, 3, 8), dtype=float32, numpy=
array([[[-1.5329047 , -0.71585095,  0.00366356, -0.1149451 ,
         -0.56995   ,  0.5137083 ,  0.89719975,  1.1282433 ],
        [-0.12693281, -0.18301146,  0.20593788,  0.9621136 ,
         -1.1635238 ,  0.19840436,  1.474231  ,  1.3114107 ],
        [ 0.        , -0.        ,  0.        ,  0.        ,
         -0.        ,  0.        , -0.        , -0.        ]],
       [[-0.5569906 ,  0.9810741 ,  1.1816455 , -0.6323199 ,
         -0.15034992,  0.92025596,  0.89326286, -0.4486617 ],
        [-0.09791443,  1.206123  ,  0.40974626, -0.23702705,
         -1.4297739 , -0.56800425,  1.4484124 ,  0.58209807],
        [-0.47014222, -0.06272995,  0.7494127 , -0.44260108,
          0.01560141,  0.9623174 , -1.083452  , -0.16848513]]],
      dtype=float32)>

未完~~等我下回继续分解。

上一篇: Centos7 忘记密码

下一篇: 日志管理