欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

蒙特卡洛dropout

程序员文章站 2022-07-11 09:09:40
...

链接是这个https://blog.csdn.net/weixin_26731327/article/details/109070481?utm_medium=distribute.pc_relevant.none-task-blog-title-2&spm=1001.2101.3001.4242
自我总结先写前面,我认为蒙特卡洛dropout首先肯定是测试的时候也开着dropout,然后就是测试n次的测试集,然后求n次的输出概率的平均值,得到不确定性,以此再取沿轴的最大值;
普通的softmax在测试时候的dropout是关闭着的。就是直接一个测试就一个输出,然后直接取沿轴最大值当作输出、

以下就是常规的分类代码,测试时,dropout在普通情况下是关闭的,但是在蒙特卡咯情况下是开启的

(X_train, y_train), (X_test, y_test) = keras.datasets.mnist.load_data()
 
 
model = keras.models.Sequential()
model.add(keras.layers.Flatten(input_shape=(28, 28)))
model.add(keras.layers.Dropout(0.25))
model.add(keras.layers.Dense(300, activation="relu"))
model.add(keras.layers.Dropout(0.25))
model.add(keras.layers.Dense(300, activation="relu"))
model.add(keras.layers.Dropout(0.25))
model.add(keras.layers.Dense(10, activation="softmax"))
 
 
optimizer = keras.optimizers.Nadam(lr=0.001)
model.compile(loss="sparse_categorical_crossentropy", 
              optimizer=optimizer, metrics=["accuracy"])
model.fit(X_train, y_train, epochs=50)
model.evaluate(X_test, y_test)

模型准确性的计算:

可以生成任意数量的预测,就是说可以预测任意多次

def predict_proba(X, model, num_samples):
    preds = [model(X, training=True) for _ in range(num_samples)]
    return np.stack(preds).mean(axis=0)
     
def predict_class(X, model, num_samples):
    proba_preds = predict_proba(X, model, num_samples)
    return np.argmax(proba_preds, axis=1)
y_pred = predict_class(X_test, model, 100)
acc = np.mean(y_pred == y_test)

由以上代码可以看出,先弄num_samples次预测,然后取平均值,然后再沿着某一轴取最大值,即可得到比原来好的预测效果。

预测不确定性

y_pred_proba = predict_proba(X_test, model, 100)
 
 
softmax_output = np.round(model.predict(X_test[1:2]), 3)
mc_pred_proba = np.round(y_pred_proba[1], 3)
print(softmax_output, mc_pred_proba)
softmax_output: [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
mc_pred_proba: [0. 0. 0.989 0.008 0.001 0. 0. 0.001 0.001 0. ]

softmax_output: [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.] 
mc_pred_proba: [0. 0. 0.989 0.008 0.001 0. 0. 0.001 0.001 0. ] [0. 0. 0.989 0.008 0.001 0. 0. 0.001 0.001 0. ]