欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Tensorflow框架 -- Dropout的用法

程序员文章站 2022-07-13 10:37:58
...

1、Dropout的作用

Dropout通过随机失活神经元,缓解网络过拟合,起到正则化作用。

2、相关函数

def tf.nn.dropout(x, keep_prob, noise_shape=None, seed=None, name=None)
  # x: 该层的输入层;
  # keep_prob: 保留输入元素的比例,同时将保留的值缩放为1/keep_prob;
  # noise_shape: None;
  # seed: None;
  # name: 该层的名字;
def tf.layer.dropout(inputs, 
                     rate=0.5, 
                     noise_shape=None, 
                     seed=None,
                     training=False,
                     name=None)
  # inputs: 该层的输入元素;
  # rate: 输入元素被随机丢弃的比例
  # noise_shape: None
  # seed: None
  # training: 当值为True时,训练时随机丢弃rate比例的输入,当预测时,设置为False
  # name: 该层的名字

3、tf.nn.dropout与tf.layer.dropout的区别

共同点:两个函数的共同点都是将输入元素随机丢弃一部分;

异同点:

  1. tf.nn.dropout 的参数keep_prob,表示保留的元素比例,而tf.layer.dropout的参数rate,表示丢弃的元素比例,所以训练过程中,经过dropout后,保留的元素比例:keep_prob = 1- rate;
  2. tf.layer.dropout的参数training,表示训练时作丢弃处理,但是预测时,需要设置为False;而tf.nn.dropout没有training参数,通常将keep_prob作为参数传入,详细见后续代码片段;

4、实际使用

首先,由于dropout随机丢弃输入的部分元素,那么如何保证输入元素的总和不变?需要对dropout的元素进行缩放,缩放的比例为1/keep_prob(tf.nn.dropout),或者1/(1-rate)(tf.layer.dropout),具体见如下代码:

import tensorflow as tf

keep_prob = tf.placeholder(tf.float32)
rate = tf.placeholedr(tf.float32)

x = tf.Variable(tf.ones([5, 5]))
y = tf.nn.dropout(x, keep_prob)
z = tf.layers.dropout(x, rate, training=True)

init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
y = sess.run(y, feed_dict={keep_prob: 0.2})
z = sess.run(z, feed_dict={rate: 0.2})
print(y)
print(z)

# 训练阶段,设置training=True, keep_prob=0.2,输出如下:
[[ 0.  0.  0.  0.  5.]
 [ 0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.]
 [ 5.  0.  0.  0.  0.]
 [ 0.  0.  0.  5.  5.]]
[[ 0.    0.    1.25  0.    1.25]
 [ 1.25  1.25  1.25  1.25  1.25]
 [ 1.25  0.    1.25  1.25  1.25]
 [ 1.25  1.25  1.25  1.25  1.25]
 [ 1.25  1.25  1.25  1.25  0.  ]]

# 测试阶段,设置training=False,keep_prob=1.0,输出如下:
[[ 1.  1.  1.  1.  1.]
 [ 1.  1.  1.  1.  1.]
 [ 1.  1.  1.  1.  1.]
 [ 1.  1.  1.  1.  1.]
 [ 1.  1.  1.  1.  1.]]
[[ 1.  1.  1.  1.  1.]
 [ 1.  1.  1.  1.  1.]
 [ 1.  1.  1.  1.  1.]
 [ 1.  1.  1.  1.  1.]
 [ 1.  1.  1.  1.  1.]]