Tensorflow框架 -- Dropout的用法
程序员文章站
2022-07-13 10:37:58
...
1、Dropout的作用
Dropout通过随机失活神经元,缓解网络过拟合,起到正则化作用。
2、相关函数
def tf.nn.dropout(x, keep_prob, noise_shape=None, seed=None, name=None)
# x: 该层的输入层;
# keep_prob: 保留输入元素的比例,同时将保留的值缩放为1/keep_prob;
# noise_shape: None;
# seed: None;
# name: 该层的名字;
def tf.layer.dropout(inputs,
rate=0.5,
noise_shape=None,
seed=None,
training=False,
name=None)
# inputs: 该层的输入元素;
# rate: 输入元素被随机丢弃的比例
# noise_shape: None
# seed: None
# training: 当值为True时,训练时随机丢弃rate比例的输入,当预测时,设置为False
# name: 该层的名字
3、tf.nn.dropout与tf.layer.dropout的区别
共同点:两个函数的共同点都是将输入元素随机丢弃一部分;
异同点:
- tf.nn.dropout 的参数keep_prob,表示保留的元素比例,而tf.layer.dropout的参数rate,表示丢弃的元素比例,所以训练过程中,经过dropout后,保留的元素比例:keep_prob = 1- rate;
- tf.layer.dropout的参数training,表示训练时作丢弃处理,但是预测时,需要设置为False;而tf.nn.dropout没有training参数,通常将keep_prob作为参数传入,详细见后续代码片段;
4、实际使用
首先,由于dropout随机丢弃输入的部分元素,那么如何保证输入元素的总和不变?需要对dropout的元素进行缩放,缩放的比例为1/keep_prob(tf.nn.dropout),或者1/(1-rate)(tf.layer.dropout),具体见如下代码:
import tensorflow as tf
keep_prob = tf.placeholder(tf.float32)
rate = tf.placeholedr(tf.float32)
x = tf.Variable(tf.ones([5, 5]))
y = tf.nn.dropout(x, keep_prob)
z = tf.layers.dropout(x, rate, training=True)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
y = sess.run(y, feed_dict={keep_prob: 0.2})
z = sess.run(z, feed_dict={rate: 0.2})
print(y)
print(z)
# 训练阶段,设置training=True, keep_prob=0.2,输出如下:
[[ 0. 0. 0. 0. 5.]
[ 0. 0. 0. 0. 0.]
[ 0. 0. 0. 0. 0.]
[ 5. 0. 0. 0. 0.]
[ 0. 0. 0. 5. 5.]]
[[ 0. 0. 1.25 0. 1.25]
[ 1.25 1.25 1.25 1.25 1.25]
[ 1.25 0. 1.25 1.25 1.25]
[ 1.25 1.25 1.25 1.25 1.25]
[ 1.25 1.25 1.25 1.25 0. ]]
# 测试阶段,设置training=False,keep_prob=1.0,输出如下:
[[ 1. 1. 1. 1. 1.]
[ 1. 1. 1. 1. 1.]
[ 1. 1. 1. 1. 1.]
[ 1. 1. 1. 1. 1.]
[ 1. 1. 1. 1. 1.]]
[[ 1. 1. 1. 1. 1.]
[ 1. 1. 1. 1. 1.]
[ 1. 1. 1. 1. 1.]
[ 1. 1. 1. 1. 1.]
[ 1. 1. 1. 1. 1.]]
上一篇: Pytorch实现L1与L2正则化
下一篇: TensorFlow Dropout
推荐阅读
-
java Swing JFrame框架类中setDefaultCloseOperation的参数含义与用法示例
-
简介PHP的Yii框架中缓存的一些高级用法
-
Zend Framework教程之MVC框架的Controller用法分析
-
Flask框架中密码的加盐哈希加密和验证功能的用法详解
-
iOS App开发中Masonry布局框架的基本用法解析
-
从源码解析Python的Flask框架中request对象的用法
-
实例解析Python的Twisted框架中Deferred对象的用法
-
详解Python的Twisted框架中reactor事件管理器的用法
-
Python(TensorFlow框架)实现手写数字识别系统的方法
-
laravel框架中间件 except 和 only 的用法示例