keras minist3
程序员文章站
2022-07-13 12:00:35
...
修改了可以使用多显卡训练
发现对于小模型,多显卡也没快起来
另:batchsize大了,训练速度是快了,但是性能狂跌
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential,Model
from keras.layers import Input,Conv2D,Dense,Dropout,Convolution2D,MaxPooling2D,Flatten
from keras.optimizers import Adam
import tensorflow as tf
from keras.utils import multi_gpu_model
import time
(x_train,y_train),(x_test,y_test) = mnist.load_data()
x_train = x_train.reshape(-1,28,28,1)/255.0
x_test = x_test.reshape(-1,28,28,1)/255.0
print(x_train.shape)
print(x_test.shape)
num_classes = 10
with tf.device('/cpu:0'):
input_image = Input(shape=(28,28,1))
cnn = Conv2D(32,(5,5),padding='same',activation='relu')(input_image)
cnn = MaxPooling2D((2,2),padding='same')(cnn)
cnn = Conv2D(64,(5,5),padding='same',activation='relu')(cnn)
cnn = MaxPooling2D((2,2),padding='same')(cnn)
cnn = Flatten()(cnn)
feature = Dense(1024,activation='relu')(cnn)
feature = Dropout(0.5)(feature)
predict = Dense(num_classes,activation='softmax',name='softmax')(feature)
model = Model(inputs=input_image, outputs=predict)
adam = Adam(lr=1e-4)
model.compile(optimizer=adam,loss='sparse_categorical_crossentropy',metrics=['accuracy'])
model.summary()
model_p = multi_gpu_model(model,4)
model_p.compile(optimizer=adam,loss='sparse_categorical_crossentropy',metrics=['accuracy'])
ta = time.time()
model_p.fit(x_train,y_train,batch_size=64,epochs=10)
tb = time.time()
print('train time = ',tb-ta)
loss,accuracy = model.evaluate(x_test,y_test)
print('test loss',loss)
print('test accuracy',accuracy)
model.save("Test.h5")
from keras.models import load_model
model = load_model('Test.h5')
model.summary()
上一篇: jquery按钮禁用(全)
下一篇: [Poi2014]FarmCraft
推荐阅读
-
tensorflow2.0之keras实现卷积神经网络
-
基于keras训练的h5模型进行批量预测
-
【翻译】Keras.NET简介 - 高级神经网络API in C#
-
ubuntu系统theano和keras的安装方法
-
(sklearn:Logistic回归)和(keras:全连接神经网络)完成mnist手写数字分类
-
简单却好用:使用Keras 2实现基于LSTM的多维时间序列预测
-
Anacoda3下安装tensorflow、keras步骤
-
Keras:Input()函数
-
win10(64)+python3.7+Anaconda3+tensorflow-cpu+Keras安装(亲测有效)
-
tensorflow2.0之keras实现卷积神经网络