PyTorch Study Notes (15): Weight Initialization
程序员文章站
2022-07-12 23:10:26
...
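Weight initialization
Vanishing and exploding gradients
If the weights are poorly scaled, the standard deviation of each layer's output either shrinks toward zero or grows without bound as the network gets deeper. The gradients are scaled by these outputs, so they vanish or explode along with them.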
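1. The expectation of the product of two independent random variables equals the product of their expectations:
$$E(XY) = E(X)E(Y)$$
2. The variance formula:
$$D(X) = E(X^2) - [E(X)]^2$$
3. The variance of the sum of two independent random variables equals the sum of their variances:
$$D(X+Y) = D(X) + D(Y)$$
From 1, 2 and 3 it follows that:
$$D(XY) = D(X)D(Y) + D(X)[E(Y)]^2 + D(Y)[E(X)]^2$$
If $E(X) = E(Y) = 0$, then:
$$D(XY) = D(X)D(Y)$$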
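The sketch below is a quick Monte Carlo check of the product-variance formula. It also checks what the formula implies for a single neuron $H = \sum_{i=1}^{n} X_i W_i$: by fact 3, $D(H) = n\,D(X)\,D(W)$ when all terms are zero-mean and independent. The distributions, sample sizes and the width n = 256 are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)

# 1) D(XY) for independent X, Y with nonzero means
N = 1_000_000
X = 1.0 + 2.0 * torch.randn(N, dtype=torch.float64)  # E(X)=1, D(X)=4
Y = 3.0 + 0.5 * torch.randn(N, dtype=torch.float64)  # E(Y)=3, D(Y)=0.25
lhs = (X * Y).var().item()
rhs = (X.var() * Y.var() + X.var() * Y.mean() ** 2 + Y.var() * X.mean() ** 2).item()
# exact value: 4*0.25 + 4*3^2 + 0.25*1^2 = 37.25
print(f"D(XY): empirical {lhs:.3f}, formula {rhs:.3f}, exact 37.25")

# 2) One layer of zero-mean neurons: D(H) = n * D(X) * D(W)
n = 256
x = torch.randn(4096, n, dtype=torch.float64)                 # D(X) = 1
w_std1 = torch.randn(n, n, dtype=torch.float64)               # D(W) = 1
w_scaled = torch.randn(n, n, dtype=torch.float64) / n ** 0.5  # D(W) = 1/n
var_std1 = (x @ w_std1).var().item()
var_scaled = (x @ w_scaled).var().item()
print(f"D(H) with std(W)=1: {var_std1:.1f} (n = {n})")
print(f"D(H) with std(W)=sqrt(1/n): {var_scaled:.3f}")
```

With standard-normal weights, the variance grows roughly n-fold per layer, so the standard deviation grows by about $\sqrt{n}$. Scaling the weights to std $\sqrt{1/n}$ keeps the variance near 1.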
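Xavier initialization
Variance consistency: keep the scale of the data in an appropriate range, typically with variance 1
Activation functions: saturating functions such as Sigmoid and Tanh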
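Xavier (Glorot) initialization requires variance consistency in two directions:

- in the forward pass, $n_i\,D(W) = 1$;
- in the backward pass, $n_{i+1}\,D(W) = 1$.

It takes the compromise $D(W) = \frac{2}{n_i + n_{i+1}}$. A uniform distribution $U(-a, a)$ has variance $a^2/3$, so $a = \sqrt{6/(n_i + n_{i+1})}$, multiplied by the activation's gain.

The sketch below assumes a square 256×256 layer and the tanh gain, as in the commented-out lines of the listing. It compares this hand-computed bound with `nn.init.xavier_uniform_`.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)
fan_in, fan_out = 256, 256               # square layer (illustrative choice)
gain = nn.init.calculate_gain("tanh")    # 5/3 for tanh

# Hand-derived bound: D(W) = 2 / (fan_in + fan_out) and D(U(-a, a)) = a^2 / 3
a = gain * math.sqrt(6.0 / (fan_in + fan_out))
expected_std = gain * math.sqrt(2.0 / (fan_in + fan_out))

w_manual = torch.empty(fan_out, fan_in).uniform_(-a, a)
w_xavier = torch.empty(fan_out, fan_in)
nn.init.xavier_uniform_(w_xavier, gain=gain)

print(f"bound a = {a:.4f}, expected std = {expected_std:.4f}")
print(f"manual uniform std  = {w_manual.std().item():.4f}")
print(f"xavier_uniform_ std = {w_xavier.std().item():.4f}")
print(f"xavier_uniform_ max |w| = {w_xavier.abs().max().item():.4f}")
```

Both tensors should have a standard deviation close to `expected_std`, and no entry should exceed the bound `a`.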
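Kaiming initialization
Variance consistency: keep the scale of the data in an appropriate range, typically with variance 1
Activation functions: ReLU and its variants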
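ReLU zeroes roughly half of its inputs, so Kaiming (He) initialization uses $D(W) = \frac{2}{n_i}$, i.e. std $\sqrt{2/n_i}$. For Leaky ReLU with negative slope $a$, this becomes $D(W) = \frac{2}{(1+a^2)\,n_i}$.

The sketch below assumes the same width (256), depth (100) and batch size (16) as the MLP in these notes. It checks the std produced by `nn.init.kaiming_normal_` with its default arguments, then runs a bias-free Linear+ReLU stack initialized this way.

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)
n, layers, batch_size = 256, 100, 16

# Defaults: a=0, mode='fan_in', nonlinearity='leaky_relu' -> gain sqrt(2)
w = torch.empty(n, n)
nn.init.kaiming_normal_(w)
expected_std = math.sqrt(2.0 / n)
print(f"kaiming_normal_ std = {w.std().item():.4f}, sqrt(2/n) = {expected_std:.4f}")

# Bias-free Linear + ReLU stack; x @ weight.t() is what nn.Linear computes without bias
x = torch.randn(batch_size, n)
for _ in range(layers):
    weight = torch.empty(n, n)
    nn.init.kaiming_normal_(weight)
    x = torch.relu(x @ weight.t())
final_std = x.std().item()
print(f"std after {layers} layers: {final_std:.4f}")
```

With standard-normal weights, this stack overflows to nan. With Kaiming initialization, the final std should stay within an order of magnitude of 1. The exact value drifts with the random seed.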
```python
# -*- coding: utf-8 -*-
import random

import numpy as np
import torch
import torch.nn as nn

try:
    from tools.common_tools import set_seed  # helper module from the course code
except ImportError:
    # Assumption: minimal stand-in used when the course's tools package is unavailable
    def set_seed(seed=1):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

set_seed(1)  # set the random seed


class MLP(nn.Module):
    # Stack `layers` bias-free Linear layers of width `neural_num` (100 layers below)
    def __init__(self, neural_num, layers):
        super(MLP, self).__init__()
        self.linears = nn.ModuleList(
            [nn.Linear(neural_num, neural_num, bias=False) for i in range(layers)]
        )
        self.neural_num = neural_num

    def forward(self, x):
        for (i, linear) in enumerate(self.linears):
            x = linear(x)
            x = torch.relu(x)
            print("layer:{}, std:{}".format(i, x.std()))
            # Stop as soon as the activations have overflowed to nan
            if torch.isnan(x.std()):
                print("output is nan in {} layers".format(i))
                break
        return x

    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # For every Linear layer, draw the weights from the standard
                # normal distribution (mean 0, std 1)
                nn.init.normal_(m.weight.data)
                # nn.init.normal_(m.weight.data, std=np.sqrt(1/self.neural_num))  # normal: mean=0, std=sqrt(1/n)
                # The Xavier options below use the tanh gain; pair them with torch.tanh in forward()
                # a = np.sqrt(6 / (self.neural_num + self.neural_num))
                #
                # tanh_gain = nn.init.calculate_gain('tanh')
                # a *= tanh_gain
                #
                # nn.init.uniform_(m.weight.data, -a, a)
                # nn.init.xavier_uniform_(m.weight.data, gain=tanh_gain)
                # nn.init.normal_(m.weight.data, std=np.sqrt(2 / self.neural_num))
                # nn.init.kaiming_normal_(m.weight.data)


# flag = 0
flag = 1
if flag:
    layer_nums = 100
    neural_nums = 256
    batch_size = 16
    net = MLP(neural_nums, layer_nums)
    net.initialize()
    inputs = torch.randn((batch_size, neural_nums))  # normal: mean=0, std=1
    output = net(inputs)
    print(output)

# ======================================= calculate gain =======================================
flag = 0
# flag = 1
if flag:
    x = torch.randn(10000)
    out = torch.tanh(x)
    gain = x.std() / out.std()
    print('gain:{}'.format(gain))
    tanh_gain = nn.init.calculate_gain('tanh')
    print('tanh_gain in PyTorch:', tanh_gain)
```
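Ten initialization methods
Xavier uniform distribution (nn.init.xavier_uniform_)
Xavier normal distribution (nn.init.xavier_normal_)
Kaiming uniform distribution (nn.init.kaiming_uniform_)
Kaiming normal distribution (nn.init.kaiming_normal_)
Uniform distribution (nn.init.uniform_)
Normal distribution (nn.init.normal_)
Constant (nn.init.constant_)
Orthogonal matrix initialization (nn.init.orthogonal_)
Identity matrix initialization (nn.init.eye_)
Sparse matrix initialization (nn.init.sparse_)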
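The sketch below applies each of the ten `torch.nn.init` functions listed above to a fresh 4×6 tensor. The shape and the argument values (uniform bounds, normal std, constant, sparsity) are arbitrary choices for illustration. `eye_` and `sparse_` only accept 2-D tensors.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# The ten initializers; each is applied to a fresh 4x6 tensor (fan_out=4, fan_in=6)
initializers = {
    "xavier_uniform_": lambda t: nn.init.xavier_uniform_(t),
    "xavier_normal_": lambda t: nn.init.xavier_normal_(t),
    "kaiming_uniform_": lambda t: nn.init.kaiming_uniform_(t, nonlinearity="relu"),
    "kaiming_normal_": lambda t: nn.init.kaiming_normal_(t, nonlinearity="relu"),
    "uniform_": lambda t: nn.init.uniform_(t, a=-0.1, b=0.1),
    "normal_": lambda t: nn.init.normal_(t, mean=0.0, std=0.01),
    "constant_": lambda t: nn.init.constant_(t, 0.5),
    "orthogonal_": lambda t: nn.init.orthogonal_(t),   # rows become orthonormal here
    "eye_": lambda t: nn.init.eye_(t),                 # 2-D only
    "sparse_": lambda t: nn.init.sparse_(t, sparsity=0.5),  # 2-D only; ceil(0.5*4)=2 zeros per column
}

results = {}
for name, init_fn in initializers.items():
    w = torch.empty(4, 6)
    init_fn(w)  # the torch.nn.init functions fill the tensor in place
    results[name] = w
    print(f"{name:>17}: mean={w.mean().item():+.4f}, std={w.std().item():.4f}")
```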
nn.init.calculate_gain
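Purpose: computes the gain of an activation function, i.e. the scale by which it changes the variance (standard deviation) of the data passing through it
Main parameters:
nonlinearity: name of the activation function
param: parameter of the activation function, e.g. negative_slope for Leaky ReLU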
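The sketch below prints the gains `nn.init.calculate_gain` returns for a few common activations. The Leaky ReLU slope of 0.2 is an arbitrary example value passed as `param`; when `param` is omitted, PyTorch uses a slope of 0.01.

```python
import math

import torch.nn as nn

gains = {name: nn.init.calculate_gain(name) for name in ["linear", "sigmoid", "tanh", "relu"]}

# Leaky ReLU: the negative slope goes in `param`; gain = sqrt(2 / (1 + slope^2))
slope = 0.2
gains["leaky_relu(0.2)"] = nn.init.calculate_gain("leaky_relu", slope)
expected_leaky = math.sqrt(2.0 / (1 + slope ** 2))

for name, g in gains.items():
    print(f"{name:>16}: {g:.4f}")
print(f"sqrt(2 / (1 + 0.2^2)) = {expected_leaky:.4f}")
```

Note that the empirical ratio `x.std() / out.std()` computed for tanh in the listing above is close to, but not exactly, the value PyTorch returns for `'tanh'`.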