欢迎您访问程序员文章站！本站旨在为大家提供和分享程序员计算机编程知识！
您现在的位置是: 首页

Keras入门 mnist训练预测

程序员文章站 2022-05-19 08:56:34
...

Keras入门 mnist训练+预测代码

from __future__ import print_function
from tensorflow.python.keras.models import Sequential, load_model
from tensorflow.python.keras.layers import Dense, Dropout
from tensorflow.python.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.python.keras.datasets import mnist
from tensorflow.python.keras.utils import to_categorical
import os


def build_model(model_name):
    """Return a Keras model, reusing a previously saved one when available.

    If ``model_name`` exists on disk it is loaded with ``load_model``.
    Otherwise a fresh, uncompiled MLP is built:
    784 -> 1024 (ReLU) -> Dropout(0.3) -> 256 (ReLU) -> Dropout(0.3) -> 10 (softmax).

    Args:
        model_name: filesystem path checked for an existing saved model.

    Returns:
        The loaded or newly constructed model.
    """
    if not os.path.exists(model_name):
        print("Making new model.")
        net = Sequential()
        # Input width 784 matches the flattened 28x28 images produced by load_mnist().
        stack = (
            Dense(1024, input_shape=(784, ), activation='relu'),
            Dropout(0.3),
            Dense(256, activation='relu'),
            Dropout(0.3),
            Dense(10, activation='softmax'),
        )
        for layer in stack:
            net.add(layer)
        return net

    print("Loading existing model.")
    return load_model(model_name)


def train(model, train_x, train_y, epochs, test_x, test_y, model_file,
          batch_size=60):
    """Compile, fit and evaluate ``model``, checkpointing the best epoch to disk.

    Args:
        model: Keras model whose output is compared against one-hot labels
            via categorical cross-entropy.
        train_x, train_y: training inputs and one-hot labels.
        epochs: maximum number of epochs; EarlyStopping may end training sooner.
        test_x, test_y: held-out inputs/labels. They are used both as
            validation data during ``fit`` (driving checkpointing and early
            stopping) and for the final ``evaluate``.
        model_file: path the ModelCheckpoint callback writes the model to.
        batch_size: mini-batch size passed to ``fit`` (default 60, the value
            previously hard-coded here).
    """
    # NOTE(review): re-compiling a model returned by load_model() discards its
    # saved optimizer state, so a resumed run restarts SGD from scratch.
    model.compile(loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])

    print("Running for %d epochs." % epochs)

    # Only write the checkpoint when val_loss improves. Writing every epoch left
    # the file holding the last epoch's weights. After early stopping, that epoch
    # is up to `patience` epochs past the best one.
    savemodel = ModelCheckpoint(model_file, monitor='val_loss', save_best_only=True)
    # Roll back to the best-val_loss weights when stopping, so the final
    # evaluation reflects the best epoch instead of the last (degraded) one.
    stopmodel = EarlyStopping(monitor='val_loss', min_delta=0.001, patience=10,
                              restore_best_weights=True)

    # NOTE(review): selecting/stopping on the test set leaks it into model
    # selection; a separate validation split would give an unbiased final score.
    model.fit(x=train_x, y=train_y, shuffle=True, batch_size=batch_size, epochs=epochs,
              validation_data=(test_x, test_y), callbacks=[savemodel, stopmodel])

    print("Done training. Now evaluating.")
    loss, acc = model.evaluate(x=test_x, y=test_y)

    print("Final loss: %3.2f Final accuracy: %3.2f" % (loss, acc))


def load_mnist():
    """Load MNIST and prepare it for the dense network in build_model().

    Each image is flattened to a 784-element row and cast to float32 in [0, 1]
    by dividing by 255. Digit labels are converted to 10-class one-hot vectors.

    Returns:
        ``((train_x, train_y), (test_x, test_y))``.
    """
    def _flatten_and_scale(images):
        # Keep one row per sample; cast first so the division happens in float32.
        rows = images.reshape(images.shape[0], 784).astype('float32')
        rows /= 255.0
        return rows

    (raw_train_x, raw_train_y), (raw_test_x, raw_test_y) = mnist.load_data()

    train_pair = (_flatten_and_scale(raw_train_x), to_categorical(raw_train_y, 10))
    test_pair = (_flatten_and_scale(raw_test_x), to_categorical(raw_test_y, 10))
    return train_pair, test_pair


def main():
    """Build (or reload) the MNIST model, then train and evaluate it for up to 50 epochs."""
    # The same path is used to look for an earlier model and as the checkpoint
    # target, so a later run starts from whatever this run saved there.
    # NOTE(review): the extension is '.hd5', not '.h5'; some Keras versions infer
    # the save format from the extension — confirm which format gets written.
    checkpoint_path = 'mnist.hd5'
    max_epochs = 50

    net = build_model(checkpoint_path)
    train_pair, test_pair = load_mnist()
    train_x, train_y = train_pair
    test_x, test_y = test_pair
    train(net, train_x, train_y, max_epochs, test_x, test_y, checkpoint_path)


if __name__ == '__main__':
    main()
相关标签: keras mnist