tensorflow常用模型开发小函数总结(持续补充中)
程序员文章站
2022-05-11 11:10:18
...
使用版本:tensorflow2.2
tf.gather
tf.gather(
params, indices, validate_indices=None, axis=None, batch_dims=0, name=None
)
官方API
用法:
在查询高维稀疏特征的对应嵌入向量值时,非常好用。高维稀疏特征就是类似于商品id这种有非常多离散取值的特征字段,通常建模时会先对它进行编码,转化为0,1,2,3…这样子,然后就可以使用tf.gather获取嵌入向量了。
#定义特征1的嵌入矩阵
emb_mat1=tf.constant(tf.random.normal([5,3]),name='embeddig_matrix1')
#定义特征2的嵌入矩阵
emb_mat2=tf.constant(tf.random.normal([8,3]),name='embeddig_matrix2')
#当样本为(2,4),即特征1取值为2,特征2取值为4时,计算嵌入向量表示
sample=(2,4)
emb1=tf.gather(emb_mat1,[sample[0]],axis=0)
emb2=tf.gather(emb_mat2,[sample[1]],axis=0)
#将嵌入向量进行连接,作为DNN或RNN的输入
vec=tf.concat([emb1,emb2],axis=1)
tf.concat
官方API
用法:
基本上做任何模型都需要做向量拼接,绝对的高频使用函数。
例子可以参考上面那个。
tf.expand_dims
官方API
用法:
给tensor做维度扩展,也是非常常用的小函数。
a=tf.constant([[1,2,3],[4,5,6]],dtype=tf.float32)
tf.expand_dims(a,axis=1)
#扩展结果,shape由2*3变为2*1*3了
<tf.Tensor: shape=(2, 1, 3), dtype=float32, numpy=
array([[[1., 2., 3.]],
[[4., 5., 6.]]], dtype=float32)>
tf.sequence_mask
用法:
在处理用户行为序列数据时常与tf.tile组合使用,假设有两个用户小a,小b,他们的历史购买行为分别为[1,2]和[3,4,5],这里的数字代表商品id,通常我们需要将两个行为序列进行对齐,转为[ [1,2,0], [3,4,5] ]这样子的2*3的矩阵。我们定义一个mask=[2,3], 它表示用户小a有两个有效行为,而小b有三个。接下来就该tf.sequence_mask登场了。
#用户历史行为矩阵
hist_behavior=tf.constant([[1,2,0],[3,4,5]],dtype=tf.int32)
#定义mask矩阵
mask=[2,3]
mask=tf.sequence_mask(mask,dtype=tf.float32)
#转化后的mask
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[1., 1., 0.],
[1., 1., 1.]], dtype=float32)>
tf.tile
tensor复制小函数
官方API
用法:
我们将上一个例子中的用户行为序列转为稠密向量表示,如下:
#定义嵌入向量矩阵
emb_mat=tf.constant(tf.random.normal([10,8]),name='embeddig_matrix')
#行为序列向量表示
hist_emb=tf.gather(emb_mat,hist_behavior,axis=0)
这时我们会发现一个问题,hist_emb中小a的第三个行为向量不为0,我们希望将这个行为向量屏蔽,方法如下:
#mask升维
mask=tf.expand_dims(mask,axis=2) #[B,H,T=1] B表示样本数量,H表示行为个数,T表示特征维度
#将mask按第二个维度进行复制
mask=tf.tile(mask,[1,1,8]) #[B,H,T=8],这样就与hist_emb维度相同了。
#屏蔽小a的第三个行为向量
hist_emb*=mask
#最终结果如下:
<tf.Tensor: shape=(2, 3, 8), dtype=float32, numpy=
array([[[-1.5329047 , -0.71585095, 0.00366356, -0.1149451 ,
-0.56995 , 0.5137083 , 0.89719975, 1.1282433 ],
[-0.12693281, -0.18301146, 0.20593788, 0.9621136 ,
-1.1635238 , 0.19840436, 1.474231 , 1.3114107 ],
[ 0. , -0. , 0. , 0. ,
-0. , 0. , -0. , -0. ]],
[[-0.5569906 , 0.9810741 , 1.1816455 , -0.6323199 ,
-0.15034992, 0.92025596, 0.89326286, -0.4486617 ],
[-0.09791443, 1.206123 , 0.40974626, -0.23702705,
-1.4297739 , -0.56800425, 1.4484124 , 0.58209807],
[-0.47014222, -0.06272995, 0.7494127 , -0.44260108,
0.01560141, 0.9623174 , -1.083452 , -0.16848513]]],
dtype=float32)>
未完~~等我下回继续分解。
上一篇: Centos7 忘记密码
下一篇: 日志管理