[Tensorflow] Batch Normalization Implementation
Advantages of BN:
(1) A larger learning rate can be used (without BN, a learning rate that is too large easily makes gradients explode/vanish, or gets the model stuck in poor local minima).
(2) Dropout is no longer needed.
(3) Training is less sensitive to weight initialization.
However, just adding BN layers is not enough; to actually train faster, the paper also changes the following (a small code sketch follows this list):
(1) Use a larger initial learning rate and decay it faster (e.g. increase the learning rate 5x from 0.0015 to 0.0075 and decay it 6x faster).
(2) Remove Dropout.
(3) Reduce the L2 weight decay (e.g. divide it by 5).
(4) Remove Local Response Normalization (LRN).
(5, 6) Other changes; see the paper for details.
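As a rough sketch of what the learning-rate and weight-decay changes could look like in TF 1.x code (the 5x / 6x / divide-by-5 factors are the ones quoted above; the baseline decay step count, the decay rate and the variable names are my own assumptions):
import tensorflow as tf

global_step = tf.train.get_or_create_global_step()

# Baseline (no BN): learning rate 0.0015, decayed every 10000 steps (assumed), L2 weight decay 5e-4 (assumed).
# With BN: 5x larger learning rate, decayed 6x faster, L2 weight decay divided by 5, dropout removed.
learning_rate = tf.train.exponential_decay(
    0.0015 * 5,              # 0.0075
    global_step,
    decay_steps=10000 // 6,  # decay 6x faster than the assumed baseline of 10000 steps
    decay_rate=0.96)         # decay_rate is an assumed value
weight_decay = 5e-4 / 5      # reduce L2 weight decay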
ResNet uses BN; for its CIFAR-10 networks the hyperparameters are:
For the 20-, 32-, 44- and 56-layer networks:
learning rate = 0.1, divided by 10 at 32k and 48k iterations;
L2 weight decay = 1e-4.
For the 110-layer network, a learning rate of 0.01 is used to warm up training until the training error drops below 80%, after which training continues with 0.1.
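A minimal sketch of this kind of step schedule in TF 1.x (the boundaries 32k/48k and the rates are the ones quoted above; the variable names are my own):
import tensorflow as tf

global_step = tf.train.get_or_create_global_step()
# Start at 0.1 and divide the learning rate by 10 at 32k and 48k iterations.
learning_rate = tf.train.piecewise_constant(
    global_step,
    boundaries=[32000, 48000],
    values=[0.1, 0.01, 0.001])
optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)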
1. tf.nn.batch_normalization
TensorFlow provides an API for Batch Normalization. However, this API is quite low-level and flexible, and the price of that flexibility is that we have to supply every parameter ourselves
(for example, we even have to compute the mean and variance of the Tensor we pass in).
tf.nn.batch_normalization(
    x,                # Tensor to be batch-normalized
    mean,             # Tensor, usually the mean of x, float32
    variance,         # Tensor, usually the variance of x, float32
    offset,           # Tensor, the beta value (the shift of BN), usually initialized to 0
    scale,            # Tensor, the gamma value (the scale of BN), usually initialized to 1
    variance_epsilon, # float, a small constant to avoid division by zero
    name=None
)
"""
返回值(Tensor):
y= (x-mean)/sqrt(variance^2+variance_epsilon)*scale+offset。
但是mean和variance需要自己提前计算,
而tensorflow又提供了另一个API来计算mean和variance。(当然我们也可以自己瞎搞一个)
"""
This API follows the paper's formulation exactly and is even more flexible: for example, mean and variance can be set to something other than the batch mean and variance of x, and the same goes for beta and gamma.
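To make the formula concrete, here is a minimal sketch (my own, not part of the original post) that computes it by hand and checks it against tf.nn.batch_normalization:
import tensorflow as tf

x = tf.constant([[1., 5.], [10., 100.]])
mean, variance = tf.nn.moments(x, axes=[0, 1])
beta, gamma, eps = 0.0, 1.0, 1e-10

# Manual BN: y = (x - mean) / sqrt(variance + eps) * gamma + beta
y_manual = (x - mean) / tf.sqrt(variance + eps) * gamma + beta
y_api = tf.nn.batch_normalization(x, mean, variance, beta, gamma, eps)

with tf.Session() as sess:
    out_manual, out_api = sess.run([y_manual, y_api])
    print(out_manual)
    print(out_api)  # should match the manual computation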
2. Computing the mean and variance of a Tensor: tf.nn.moments
Since the API above requires us to compute mean and variance ourselves, this is where tf.nn.moments comes in.
tf.nn.moments(
    x,     # the Tensor whose mean and variance we want
    axes,  # the dimensions to reduce over; for the whole-tensor BN used in this post that is
           # all of the dimensions, i.e. [d for d in range(len(x.get_shape()))]
    shift=None,
    name=None,
    keep_dims=False
)
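Note that for convolutional layers, BN is usually applied per channel rather than over every dimension. A minimal sketch of the per-channel case (my own example; the shapes are illustrative):
import tensorflow as tf

# NHWC feature map: batch of 8, 28x28 spatial size, 16 channels
feature_map = tf.random_normal([8, 28, 28, 16])

# Reduce over batch, height and width; keep one mean/variance per channel.
mean, variance = tf.nn.moments(feature_map, axes=[0, 1, 2])

with tf.Session() as sess:
    m, v = sess.run([mean, variance])
    print(m.shape, v.shape)  # (16,) (16,)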
3. Example
import tensorflow as tf
sess=tf.Session()
x=tf.constant([[1,5],[10,100]],dtype=tf.float32)
# all dimensions of x
axes=[d for d in range(len(x.get_shape()))]
# the beta and gamma parameters
beta= tf.get_variable("beta",shape=[],initializer=tf.constant_initializer(0.0))
gamma=tf.get_variable("gamma",shape=[],initializer=tf.constant_initializer(1.0))
sess.run(tf.global_variables_initializer())
# compute mean and variance, then apply BN
x_mean,x_variance=tf.nn.moments(x,axes)
y=tf.nn.batch_normalization(x,x_mean,x_variance,beta,gamma,1e-10,"bn")
# inspect the result
y_mean,y_variance=tf.nn.moments(y,axes)
x_val,xm_val,xv_val,y_val,ym_val,yv_val=sess.run([x,x_mean,x_variance,y,y_mean,y_variance])
print("*********Variable x before BN:************")
print("x=%s\n x mean=%s\n x variance=%s" %(x_val,xm_val,xv_val))
print("*********Variable y after BN:************")
print("y=%s \n y mean=%s\n y variance=%s" %(y_val,ym_val,yv_val))
The output is:
*********Variable x before BN:************
x=[[ 1. 5.]
[ 10. 100.]]
x mean=29.0
x variance=1690.5
*********Variable y after BN:************
y=[[-0.68100518 -0.58371872]
[-0.46211064 1.72683454]]
y mean=0.0
y variance=1.0
So x, after BN, becomes y with mean 0 and variance 1 (when beta is 0 and gamma is 1).
If we change the initial values of beta and gamma, the mean of y becomes beta and its variance becomes gamma^2 (a small check of this follows below).
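A quick sketch of that claim (my own check, in the same style as the example above; the values 2.0 and 3.0 are arbitrary):
import tensorflow as tf

x = tf.constant([[1., 5.], [10., 100.]])
axes = [0, 1]
beta, gamma = 2.0, 3.0

x_mean, x_variance = tf.nn.moments(x, axes)
y = tf.nn.batch_normalization(x, x_mean, x_variance, beta, gamma, 1e-10)
y_mean, y_variance = tf.nn.moments(y, axes)

with tf.Session() as sess:
    print(sess.run([y_mean, y_variance]))  # approximately [2.0, 9.0] = [beta, gamma**2]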
4. Where to put the BN layer
In a BN network, a convolutional or fully connected layer applies three operations to its input x: the BN operation, the weight operation (conv/matmul), and the ReLU operation. In what order should they be arranged?
The original paper says that any layer that previously received x as input now receives BN(x). But what about the sub-operations inside a single convolutional layer?
For a 2*conv16 => 2*conv32 => 2*conv64 => fc-10 network on MNIST, I tried the following three orderings (sketched in code below):
(1) x -> bn -> weight -> relu
(2) x -> bn -> relu -> weight
(3) x -> weight -> bn -> relu
All three worked about equally well, probably because this dataset is too simple; it needs more testing later.
In the ~1000-layer ResNet, however, the second ordering works better than the third (when shortcuts are present).
Paper: https://arxiv.org/pdf/1603.05027.pdf
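A minimal sketch of the three orderings as helper functions (my own illustration; conv and bn stand for the corresponding layer functions, as in the VGG class further below):
import tensorflow as tf

def block_v1(x, conv, bn):
    # (1) x -> bn -> weight -> relu
    return tf.nn.relu(conv(bn(x)))

def block_v2(x, conv, bn):
    # (2) x -> bn -> relu -> weight  (pre-activation, as in the ~1000-layer ResNet)
    return conv(tf.nn.relu(bn(x)))

def block_v3(x, conv, bn):
    # (3) x -> weight -> bn -> relu  (the ordering in the original BN paper)
    return tf.nn.relu(bn(conv(x)))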
5. Comparing the effect of BN on MNIST
MNIST is too simple: with an ordinary CNN, adding BN layers makes no obvious difference.
So we make training harder for the model by replacing ReLU with Sigmoid.
(Sigmoid makes training incredibly slow, easily a hundred times slower; at first I thought something was wrong with the network. ReLU really is powerful!)
All other settings are the same: a 784*100*100*100*10 fully connected network, learning rate 1e-4, momentum 0.9, L2 weight decay 1e-4, batch size 50, trained for 10 epochs.
Training results without BN:
[step100] accuracy=0.1 loss=116.691
[step200] accuracy=0.16 loss=114.826
[step300] accuracy=0.1 loss=115.051
[step400] accuracy=0.1 loss=117.023
[step500] accuracy=0.08 loss=115.734
[step600] accuracy=0.14 loss=114.13
[step700] accuracy=0.1 loss=115.985
[step800] accuracy=0.1 loss=115.7
[step900] accuracy=0.02 loss=117.614
[step1000] accuracy=0.1 loss=115.558
[*]Test Result=0.0892000000738 at epoch0
[step100] accuracy=0.08 loss=114.817
[step200] accuracy=0.22 loss=113.812
[step300] accuracy=0.1 loss=115.722
[step400] accuracy=0.04 loss=116.21
[step500] accuracy=0.14 loss=115.215
[step600] accuracy=0.08 loss=115.071
[step700] accuracy=0.14 loss=115.076
[step800] accuracy=0.06 loss=116.63
[step900] accuracy=0.12 loss=114.81
[step1000] accuracy=0.08 loss=115.669
[*]Test Result=0.100900000408 at epoch1
[step100] accuracy=0.1 loss=115.425
[step200] accuracy=0.1 loss=115.394
[step300] accuracy=0.08 loss=115.214
[step400] accuracy=0.04 loss=114.856
[step500] accuracy=0.08 loss=117.108
[step600] accuracy=0.14 loss=113.223
[step700] accuracy=0.08 loss=115.142
[step800] accuracy=0.16 loss=114.448
[step900] accuracy=0.1 loss=114.995
[step1000] accuracy=0.18 loss=115.651
[*]Test Result=0.113499999568 at epoch2
[step100] accuracy=0.12 loss=114.254
[step200] accuracy=0.08 loss=116.074
[step300] accuracy=0.2 loss=113.781
[step400] accuracy=0.08 loss=115.302
[step500] accuracy=0.06 loss=115.785
[step600] accuracy=0.08 loss=116.462
[step700] accuracy=0.08 loss=114.897
[step800] accuracy=0.14 loss=116.592
[step900] accuracy=0.1 loss=116.425
[step1000] accuracy=0.06 loss=114.058
[*]Test Result=0.103200000077 at epoch3
[step100] accuracy=0.26 loss=113.873
[step200] accuracy=0.08 loss=115.774
[step300] accuracy=0.14 loss=114.722
[step400] accuracy=0.1 loss=114.43
[step500] accuracy=0.12 loss=114.766
[step600] accuracy=0.08 loss=116.453
[step700] accuracy=0.02 loss=116.828
[step800] accuracy=0.06 loss=115.831
[step900] accuracy=0.14 loss=114.576
[step1000] accuracy=0.04 loss=114.588
[*]Test Result=0.113499999568 at epoch4
[step100] accuracy=0.24 loss=114.013
[step200] accuracy=0.1 loss=115.269
[step300] accuracy=0.08 loss=115.71
[step400] accuracy=0.18 loss=113.4
[step500] accuracy=0.14 loss=115.153
[step600] accuracy=0.08 loss=114.52
[step700] accuracy=0.12 loss=114.871
[step800] accuracy=0.22 loss=115.017
[step900] accuracy=0.12 loss=113.872
[step1000] accuracy=0.12 loss=115.084
[*]Test Result=0.171800000742 at epoch5
[step100] accuracy=0.12 loss=116.787
[step200] accuracy=0.1 loss=116.283
[step300] accuracy=0.04 loss=115.422
[step400] accuracy=0.14 loss=114.826
[step500] accuracy=0.18 loss=114.08
[step600] accuracy=0.14 loss=114.935
[step700] accuracy=0.18 loss=114.367
[step800] accuracy=0.02 loss=115.996
[step900] accuracy=0.08 loss=114.403
[step1000] accuracy=0.24 loss=113.339
[*]Test Result=0.113499999568 at epoch6
[step100] accuracy=0.18 loss=114.502
[step200] accuracy=0.12 loss=114.226
[step300] accuracy=0.14 loss=114.238
[step400] accuracy=0.28 loss=113.135
[step500] accuracy=0.04 loss=115.067
[step600] accuracy=0.16 loss=113.927
[step700] accuracy=0.1 loss=113.124
[step800] accuracy=0.06 loss=114.841
[step900] accuracy=0.16 loss=113.212
[step1000] accuracy=0.26 loss=112.934
[*]Test Result=0.199200000018 at epoch7
[step100] accuracy=0.16 loss=114.148
[step200] accuracy=0.12 loss=113.84
[step300] accuracy=0.14 loss=112.673
[step400] accuracy=0.2 loss=112.878
[step500] accuracy=0.2 loss=114.386
[step600] accuracy=0.12 loss=112.982
[step700] accuracy=0.38 loss=111.301
[step800] accuracy=0.3 loss=112.395
[step900] accuracy=0.52 loss=110.003
[step1000] accuracy=0.12 loss=111.22
[*]Test Result=0.122199999765 at epoch8
[step100] accuracy=0.08 loss=112.523
[step200] accuracy=0.42 loss=108.418
[step300] accuracy=0.4 loss=105.239
[step400] accuracy=0.5 loss=98.153
[step500] accuracy=0.22 loss=103.485
[step600] accuracy=0.2 loss=104.636
[step700] accuracy=0.48 loss=95.7585
[step800] accuracy=0.24 loss=94.8633
[step900] accuracy=0.38 loss=93.5662
[step1000] accuracy=0.36 loss=89.0528
[*]Test Result=0.351300003231 at epoch9
After 10 epochs, the test accuracy only reaches about 35%.
Training results with the BN layers added:
[step100] accuracy=0.1 loss=116.102
[step200] accuracy=0.22 loss=112.854
[step300] accuracy=0.14 loss=115.377
[step400] accuracy=0.1 loss=115.649
[step500] accuracy=0.1 loss=115.625
[step600] accuracy=0.24 loss=114.879
[step700] accuracy=0.1 loss=115.61
[step800] accuracy=0.12 loss=114.699
[step900] accuracy=0.14 loss=115.097
[step1000] accuracy=0.1 loss=114.932
[*]Test Result=0.0974000002816 at epoch0
[step100] accuracy=0.1 loss=116.12
[step200] accuracy=0.06 loss=116.164
[step300] accuracy=0.1 loss=115.818
[step400] accuracy=0.12 loss=115.697
[step500] accuracy=0.18 loss=115.264
[step600] accuracy=0.2 loss=114.414
[step700] accuracy=0.04 loss=115.895
[step800] accuracy=0.12 loss=114.564
[step900] accuracy=0.06 loss=115.524
[step1000] accuracy=0.22 loss=114.622
[*]Test Result=0.161500000656 at epoch1
[step100] accuracy=0.2 loss=115.315
[step200] accuracy=0.14 loss=114.43
[step300] accuracy=0.1 loss=115.918
[step400] accuracy=0.16 loss=114.786
[step500] accuracy=0.26 loss=112.941
[step600] accuracy=0.3 loss=113.985
[step700] accuracy=0.3 loss=112.463
[step800] accuracy=0.14 loss=113.471
[step900] accuracy=0.14 loss=112.914
[step1000] accuracy=0.14 loss=112.23
[*]Test Result=0.24730000034 at epoch2
[step100] accuracy=0.2 loss=111.719
[step200] accuracy=0.32 loss=108.348
[step300] accuracy=0.24 loss=106.837
[step400] accuracy=0.36 loss=102.211
[step500] accuracy=0.32 loss=99.1392
[step600] accuracy=0.42 loss=94.0066
[step700] accuracy=0.5 loss=82.9231
[step800] accuracy=0.5 loss=78.0428
[step900] accuracy=0.56 loss=75.0709
[step1000] accuracy=0.56 loss=72.2615
[*]Test Result=0.569599996507 at epoch3
[step100] accuracy=0.54 loss=72.2187
[step200] accuracy=0.62 loss=62.6503
[step300] accuracy=0.7 loss=51.1989
[step400] accuracy=0.7 loss=50.0574
[step500] accuracy=0.62 loss=48.4715
[step600] accuracy=0.58 loss=56.4319
[step700] accuracy=0.76 loss=48.7727
[step800] accuracy=0.76 loss=39.0827
[step900] accuracy=0.66 loss=44.0735
[step1000] accuracy=0.74 loss=40.6393
[*]Test Result=0.731999999881 at epoch4
[step100] accuracy=0.82 loss=39.1621
[step200] accuracy=0.8 loss=33.0594
[step300] accuracy=0.68 loss=41.5027
[step400] accuracy=0.72 loss=49.6565
[step500] accuracy=0.8 loss=32.1081
[step600] accuracy=0.8 loss=42.5631
[step700] accuracy=0.84 loss=31.7484
[step800] accuracy=0.8 loss=34.406
[step900] accuracy=0.7 loss=36.0701
[step1000] accuracy=0.76 loss=39.4207
[*]Test Result=0.798400003314 at epoch5
[step100] accuracy=0.66 loss=38.2423
[step200] accuracy=0.88 loss=23.5632
[step300] accuracy=0.8 loss=37.7658
[step400] accuracy=0.8 loss=41.1382
[step500] accuracy=0.84 loss=31.7916
[step600] accuracy=0.86 loss=24.6395
[step700] accuracy=0.8 loss=29.7371
[step800] accuracy=0.84 loss=33.4366
[step900] accuracy=0.84 loss=25.56
[step1000] accuracy=0.92 loss=23.0958
[*]Test Result=0.841499999762 at epoch6
[step100] accuracy=0.9 loss=17.4944
[step200] accuracy=0.74 loss=35.0277
[step300] accuracy=0.9 loss=30.2663
[step400] accuracy=0.78 loss=34.679
[step500] accuracy=0.82 loss=25.4055
[step600] accuracy=0.86 loss=19.0345
[step700] accuracy=0.98 loss=14.34
[step800] accuracy=0.86 loss=27.425
[step900] accuracy=0.78 loss=35.237
[step1000] accuracy=0.88 loss=23.2125
[*]Test Result=0.86880000174 at epoch7
[step100] accuracy=0.88 loss=23.3765
[step200] accuracy=0.82 loss=33.0606
[step300] accuracy=0.76 loss=44.3354
[step400] accuracy=0.9 loss=17.5737
[step500] accuracy=0.82 loss=27.3082
[step600] accuracy=0.92 loss=18.8941
[step700] accuracy=0.84 loss=27.9557
[step800] accuracy=0.9 loss=16.8646
[step900] accuracy=0.92 loss=12.3513
[step1000] accuracy=0.9 loss=22.4553
[*]Test Result=0.886900002956 at epoch8
After 9 epochs the accuracy is already around 88%. A rough estimate: to reach the same 35% test accuracy, the network without BN needs 10 epochs while the one with BN needs about 3.4 epochs,
i.e. roughly a 3x speedup. That is about the same as the BN-Baseline vs. Inception speedup reported in the paper, and it could probably be made faster by tuning the learning rate.
Code:
###VGG.PY#########
import tensorflow as tf
"""
(1)构造函数__init__参数
input_sz: 输入层placeholder的4-D shape,如mnist是[None,28,28,1]
fc_layers: 全连接层每一层大小,接在卷积层后面。如mnist可以为[128,84,10],[10]
conv_info: 卷积层、池化层。
如vgg16可以这样写:[(2,64),(2,128),(3,256),(3,512),(3,512)],表示2+2+3+3+3=13个卷积层,4个池化层,以及channels
(2)train函数:训练一步
batch_input: 输入的batch
batch_output: label
learning_rate:学习率
返回:正确率和loss值(float) 格式:{"accuracy":accuracy,"loss":loss}
(3)forward:训练后用于测试
(4)save(save_path,steps)保存模型
(5)restore(path):从文件夹中读取最后一个模型
(6)loss函数使用cross-entrop one-hot版本:y*log(y_net)
(7)optimizer使用adamoptimier
"""
class VGG: # VGG-style classifier
    sess=None
    # Tensors
    input=None
    output=None
    desired_out=None
    loss=None
    iscorrect=None
    accuracy=None
    optimizer=None
    param_num=0 # total number of parameters
    # hyperparameters
    learning_rate=None
    MOMENTUM = 0.9
    WEIGHT_DECAY = 1e-4 # L2 regularization
    ACTIVATE = None
    CONV_PADDING = "SAME"
    MAX_POOL_PADDING = "SAME"
    CONV_WEIGHT_INITAILIZER = tf.truncated_normal_initializer(stddev=0.1)
    CONV_BIAS_INITAILIZER = tf.constant_initializer(value=0.0)
    FC_WEIGHT_INITAILIZER = tf.truncated_normal_initializer(stddev=0.1)
    FC_BIAS_INITAILIZER = tf.constant_initializer(value=0.0)

    def train(self,batch_input,batch_output,learning_rate):
        _,accuracy,loss=self.sess.run([self.optimizer,self.accuracy,self.loss],
            feed_dict={self.input:batch_input,self.desired_out:batch_output,self.learning_rate:learning_rate})
        return {"accuracy":accuracy,"loss":loss}

    def forward(self,batch_input):
        return self.sess.run(self.output,feed_dict={self.input:batch_input})

    def save(self,save_path,steps):
        saver=tf.train.Saver(max_to_keep=5)
        saver.save(self.sess,save_path,global_step=steps)

    def restore(self,restore_path):
        path=tf.train.latest_checkpoint(restore_path)
        print("[*]Restore from %s" %(path))
        if path==None:
            return False
        saver=tf.train.Saver(max_to_keep=5)
        saver.restore(self.sess,path)
        return True

    def bn(self,x,name="bn"):
        #return x   # uncomment to disable BN
        # Normalize over all dimensions of x with a single scalar beta/gamma.
        axes = [d for d in range(len(x.get_shape()))]
        beta = self._get_variable("beta", shape=[],initializer=tf.constant_initializer(0.0))
        gamma= self._get_variable("gamma",shape=[],initializer=tf.constant_initializer(1.0))
        x_mean,x_variance=tf.nn.moments(x,axes)
        y=tf.nn.batch_normalization(x,x_mean,x_variance,beta,gamma,1e-10,name)
        return y

    def get_optimizer(self):
        # Optimizer
        #self.optimizer = tf.train.AdamOptimizer(self.learning_rate).minimize(self.loss)
        #self.optimizer = tf.train.GradientDescentOptimizer(self.learning_rate).minimize(self.loss) # reaches the target error after ~1300 steps
        self.optimizer = tf.train.MomentumOptimizer(self.learning_rate,self.MOMENTUM).minimize(self.loss) # reaches the target error after ~9000 steps

    # one convolution (weight + bias) applied to x; the activation is applied by the caller
    def conv(self,x,name,channels,ksize=3):
        x_shape=x.get_shape()
        x_channels=x_shape[3].value
        weight_shape=[ksize,ksize,x_channels,channels]
        bias_shape=[channels]
        weight = self._get_variable("weight",weight_shape,initializer=self.CONV_WEIGHT_INITAILIZER)
        bias = self._get_variable("bias",bias_shape,initializer=self.CONV_BIAS_INITAILIZER)
        y=tf.nn.conv2d(x,weight,strides=[1,1,1,1],padding=self.CONV_PADDING,name=name)
        y=tf.add(y,bias,name=name)
        return y

    def max_pool(self,x,name):
        return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding=self.MAX_POOL_PADDING,name=name)

    # _get_variable wraps tf.get_variable so that L2 regularization and parameter counting happen in one place
    def _get_variable(self,name,shape,initializer):
        param=1
        for i in range(0,len(shape)):
            param*=shape[i]
        self.param_num+=param
        if self.WEIGHT_DECAY>0:
            regularizer=tf.contrib.layers.l2_regularizer(self.WEIGHT_DECAY)
        else:
            regularizer=None
        return tf.get_variable(name,
                               shape=shape,
                               initializer=initializer,
                               regularizer=regularizer)

    def fc(self,x,num,name):
        x_num=x.get_shape()[1].value
        weight_shape=[x_num,num]
        bias_shape =[num]
        weight=self._get_variable("weight",shape=weight_shape,initializer=self.FC_WEIGHT_INITAILIZER)
        bias  =self._get_variable("bias",shape=bias_shape,initializer=self.FC_BIAS_INITAILIZER)
        y=tf.add(tf.matmul(x,weight),bias,name=name)
        return y

    def _loss(self):
        cross_entropy=-tf.reduce_sum(self.desired_out*tf.log(tf.clip_by_value(self.output,1e-10,1.0)))
        regularization_losses=tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        self.loss = tf.add_n([cross_entropy]+regularization_losses)
        #tf.scalar_summary('loss', loss_)
        return self.loss

    def __init__(self,input_sz,fc_layers,conv_info=[],activate_fun=tf.nn.relu):
        self.ACTIVATE=activate_fun
        self.param_num=0 # total parameter count
        self.sess=tf.Session()
        layers=[]
        #(1) placeholders (input, desired output, learning rate)
        #input
        self.input=tf.placeholder(tf.float32,input_sz,name="input")
        layers.append(self.input)
        # BN applied directly to the input
        layers.append(self.bn(layers[-1]))
        #output
        output_sz=[None,fc_layers[-1]]
        self.desired_out=tf.placeholder(tf.float32,output_sz,name="desired_out")
        self.learning_rate=tf.placeholder(tf.float32,name="learning_rate")
        #(2) convolution + pooling blocks
        with tf.variable_scope("convolution"):
            conv_block_id=0
            for cur_layers in conv_info:
                # add one block of stacked conv layers
                with tf.variable_scope("conv_block_%d" %(conv_block_id)) as scope:
                    cur_conv_num=cur_layers[0] # number of conv layers stacked in this block
                    cur_channels=cur_layers[1] # channels of each conv layer in this block
                    # stack cur_conv_num conv layers
                    for conv_id in range(0,cur_conv_num):
                        with tf.variable_scope("conv_%d" %(conv_id)):
                            # add one conv layer
                            x=layers[-1]
                            """
                            # ordering 1: x -> bn -> weight -> relu
                            x2=self.bn(x)
                            x3=self.conv(x2,channels=cur_channels,name="conv")
                            x4=self.ACTIVATE(x3)
                            """
                            #"""
                            # ordering 2: x -> bn -> relu -> weight
                            x2=self.bn(x)
                            x3=self.ACTIVATE(x2)
                            x4=self.conv(x3,channels=cur_channels,name="conv")
                            #"""
                            """
                            # ordering 3: x -> weight -> bn -> relu
                            x2=self.conv(x,channels=cur_channels,name="conv")
                            x3=self.bn(x2)
                            x4=self.ACTIVATE(x3)
                            """
                            layers.append(x4)
                    # a max-pool layer after each conv block
                    last_layer=layers[-1]
                    pool=self.max_pool(last_layer,"max_pool")
                    layers.append(pool)
                conv_block_id+=1
        #(3) flatten the conv output
        last_layer=layers[-1]
        last_shape=last_layer.get_shape()
        neu_num=1
        for dim in range(1,len(last_shape)):
            neu_num*=last_shape[dim].value
        flat_layer=tf.reshape(last_layer,[-1,neu_num],name="flatten")
        layers.append(flat_layer)
        #(4) fully connected layers. NOTE: no ReLU after the last layer!
        with tf.variable_scope("full_connection"):
            for fc_id in range(0,len(fc_layers)):
                with tf.variable_scope("fc_%d" %(fc_id)):
                    num=fc_layers[fc_id]
                    x=layers[-1]
                    x2=self.bn(x)
                    x3=self.ACTIVATE(x2,name="relu") # activate the BN output, then apply the weights
                    y=self.fc(x3,num,"fc")
                    layers.append(y)
        #(5) softmax and the loss function
        self.output=tf.nn.softmax(layers[-1])
        # loss
        self._loss()
        #(6) bookkeeping: accuracy
        self.iscorrect=tf.equal(tf.argmax(self.desired_out,1),tf.argmax(self.output,1),name="iscorrect")
        self.accuracy=tf.reduce_mean(tf.cast(self.iscorrect,dtype=tf.float32),name="accuracy")
        #(7) optimizer and variable initialization
        self.get_optimizer()
        self.sess.run(tf.global_variables_initializer())
        writer = tf.summary.FileWriter("./tboard/",self.sess.graph)

    def __del__(self):
        self.sess.close()
####VGG_MNIST.PY####
import VGG
import tensorflow as tf
import sys
from tensorflow.examples.tutorials.mnist import input_data

vgg=VGG.VGG([None,28,28,1],[100,100,100,10],activate_fun=tf.sigmoid)#,[(3,16),(3,32),(3,64),(3,128)])
#vgg=VGG.VGG([None,28,28,1],[10],[(2,16),(2,32),(2,64)])
print("param_num=%d" %(vgg.param_num))
#writer = tf.summary.FileWriter("./tboard/",vgg.sess.graph)
mnist = input_data.read_data_sets("input_data", one_hot=True)

def get_mnist_batch(num,get_test=False):
    batch=None
    if get_test:
        batch=[mnist.test.images,mnist.test.labels]
    else:
        batch=mnist.train.next_batch(num)
    input=[]
    for x in batch[0]:
        # reshape the flat 784-vector into a 28x28x1 image
        inp=[[0 for _ in range(0,28)] for _ in range(0,28)]
        for row in range(0,28):
            for col in range(0,28):
                inp[row][col]=[x[row*28+col]]
        """
        # debug: print the digit as ASCII art
        if inp[row][col][0]>0.6:
            print(" ",end="")
        else:
            if inp[row][col][0]>0.3:
                print(".",end="")
            else:
                print("w",end="")
        if col==27:
            print("")
        sys.exit(0)
        """
        input.append(inp)
    return input,batch[1]

def get_mnist_test_accuracy():
    batch=get_mnist_batch(0,True)
    accuracy=0
    for st in range(0,10000,100):
        ret=vgg.train(batch[0][st:st+100],batch[1][st:st+100],learning_rate=0)
        accuracy+=ret["accuracy"]/100
    return accuracy

"""
if vgg.restore("./model/"):
    test_acc=get_mnist_test_accuracy()
    print("[*]Test Result=%s at epoch%d" %(test_acc,0))
"""

learning_rate=1e-4
for epoch in range(0,10):
    batch_sz=50
    for i in range(int(50000/batch_sz)):
        batch = get_mnist_batch(batch_sz)
        ret=vgg.train(batch[0],batch[1],learning_rate=learning_rate)
        if i%100==0:
            #print(batch[1][0])
            #print(ret[0][0])
            print("[step%d] accuracy=%s loss=%s" %(i+100,ret["accuracy"],ret["loss"]))
    #learning_rate/=2
    vgg.save("model/mnist_epoch",epoch)
    test_acc=get_mnist_test_accuracy()
    print("[*]Test Result=%s at epoch%d" %(test_acc,epoch))