Using Keras Pretrained Models with TensorFlow

tensorflow.keras.applications is a very useful library: it ships many models pretrained on ImageNet. I am used to training models with TensorFlow's native API, though, so how can the Keras model zoo be used inside a TensorFlow training loop?
References:

  1. https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html
  2. http://zachmoshe.com/2017/11/11/use-keras-models-with-tf.html
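
In short, the trick is that a tf.keras application called on an ordinary TensorFlow tensor simply returns another tensor, which can then be wired into a hand-written graph. A minimal sketch of the idea (assuming TF 1.x; the 3-class head mirrors the training code below):

import tensorflow as tf
from tensorflow import keras as k

# Placeholder for a batch of 224x224 RGB images.
x = tf.placeholder(tf.float32, (None, 224, 224, 3))
# Calling the pretrained application on x yields a plain TF tensor...
features = k.applications.ResNet50(include_top=False, pooling="avg")(x)
# ...so further layers (or raw TF ops) can be stacked on top of it.
logits = k.layers.Dense(3)(features)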

Dataset class

#! -*- coding: utf-8 -*-
from bases.base_dataset import BaseDataset
import tensorflow as tf


class PoliticDataset(BaseDataset):
    """tf.data pipeline that reads (image_path, label) pairs from a text file."""

    def __init__(self, filename, batch_size, image_size, train, prefetch_size, num_workers=8):
        self.filename = filename
        self.batch_size = batch_size
        self.image_size = image_size
        self.train = train
        self.prefetch_size = prefetch_size
        self.num_workers = num_workers
        self._init_dataset()

    def _init_dataset(self):
        # Each line of the annotation file is "<image_path> <label>".
        items = []
        for line in open(self.filename):
            image_path, image_label = line.strip("\n").split()
            items.append((image_path, int(image_label)))
        self.names, self.labels = zip(*items)

    def _transform(self, filename, label):
        # Read, decode and resize a single image; force 3 channels so
        # grayscale images do not break the (H, W, 3) input placeholder.
        image_string = tf.read_file(filename)
        image_decoded = tf.image.decode_jpeg(image_string, channels=3)
        image_resized = tf.image.resize_images(image_decoded, self.image_size)
        return image_resized, label

    def __len__(self):
        return len(self.names)

    def dataset(self):
        dataset = tf.data.Dataset.from_tensor_slices((tf.constant(self.names), tf.constant(self.labels)))

        if self.train:
            dataset = dataset.shuffle(buffer_size=len(self.names))

        dataset = dataset.map(self._transform, num_parallel_calls=self.num_workers).batch(self.batch_size)

        if self.train:
            # Repeat indefinitely for training; the test set is consumed once per pass.
            dataset = dataset.repeat()

        dataset = dataset.prefetch(self.prefetch_size)
        # Return the dataset itself; callers build their own iterator
        # (the original stray make_one_shot_iterator() call here was unused).
        return dataset
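
One thing this pipeline does not do is apply the input normalization that keras.applications models were trained with (for ResNet50, Caffe-style per-channel mean subtraction). A hedged variant of _transform that adds it, assuming preprocess_input accepts symbolic tensors in your TF version:

# At module level:
from tensorflow.keras.applications.resnet50 import preprocess_input

# Inside PoliticDataset:
    def _transform(self, filename, label):
        image_string = tf.read_file(filename)
        image_decoded = tf.image.decode_jpeg(image_string, channels=3)
        image_resized = tf.image.resize_images(image_decoded, self.image_size)
        # Assumption: preprocess_input handles symbolic tensors here; if not,
        # the same mean subtraction can be written with raw TF ops.
        return preprocess_input(image_resized), label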

Training the model

#! -*- coding: utf-8 -*-
from tensorflow import keras as k
from dataset import PoliticDataset
import tensorflow as tf
import numpy as np
import shutil
import os

# Dataset configuration
TRAIN_FILE = "data/train.txt"
TEST_FILE = "data/test.txt"
TRAIN_BATCH_SIZE = 32
TEST_BATCH_SIZE = 64
PREFETCH_SIZE = 64

# Image parameters
IMAGE_HEIGHT = 224
IMAGE_WIDTH = 224
IMAGE_CHANS = 3
IMAGE_SIZE = (IMAGE_HEIGHT, IMAGE_WIDTH)

# Hyperparameters
EPOCHS = 1000
STEPS_PER_ITER = 100
TEST_BY_EPOCHS = 10
LR = 0.001
MAX_TO_KEEP = 5

# Model checkpoint and log directories
SAVE_ROOT = "../experiments/"
PROJECT_NAME = "politic_simpleNet_v1"

# GPU configuration (renamed the first variable to avoid shadowing)
gpu_options = tf.GPUOptions(
    allow_growth=True,
    per_process_gpu_memory_fraction=0.99,
)
gpu_config = tf.ConfigProto(
    log_device_placement=False,
    allow_soft_placement=True,
    gpu_options=gpu_options,
)


def writer_log(writer, tag, loss, acc, global_step):
    """写入日志"""
    summary = tf.summary.Summary(
        value=[
            tf.summary.Summary.Value(tag=tag + "/loss", simple_value=loss),
            tf.summary.Summary.Value(tag=tag + "/acc", simple_value=acc)
        ]
    )
    writer.add_summary(summary, global_step)


def mkdir(dir_name, delete=False):
    """Create dir_name; with delete=True, wipe any existing directory first."""
    if os.path.exists(dir_name):
        if not delete:
            return
        shutil.rmtree(dir_name)
    print("Create %s" % dir_name)
    os.makedirs(dir_name)
    print("Create succeeded.")


# Create directories
LOG_DIR = os.path.join(SAVE_ROOT, PROJECT_NAME, "logs/")  # log directory
CHECKPOINT_DIR = os.path.join(SAVE_ROOT, PROJECT_NAME, "checkpoints/")  # checkpoint directory
mkdir(LOG_DIR, delete=True)
mkdir(CHECKPOINT_DIR, delete=False)


def main():
    # Datasets
    trainset = PoliticDataset(TRAIN_FILE, TRAIN_BATCH_SIZE, IMAGE_SIZE, train=True, prefetch_size=PREFETCH_SIZE)
    print("Train batches per epoch: ", len(trainset) // TRAIN_BATCH_SIZE)
    testset = PoliticDataset(TEST_FILE, TEST_BATCH_SIZE, IMAGE_SIZE, train=False, prefetch_size=PREFETCH_SIZE)
    print("Test batches: ", len(testset) // TEST_BATCH_SIZE)

    trainiter = trainset.dataset().make_one_shot_iterator()
    testiter = testset.dataset().make_initializable_iterator()

    train_next = trainiter.get_next()
    test_next = testiter.get_next()

    # Global counters
    global_step = tf.Variable(0, trainable=False, name="global_step")
    global_epoch = tf.Variable(0, trainable=False, name="global_epoch")

    # Global ops
    global_epoch_increment = tf.assign_add(global_epoch, 1)

    # Placeholders
    x = tf.placeholder(tf.float32, (None, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANS), name="inputs")
    y = tf.placeholder(tf.int64, (None, ), name="label")
    # training = tf.placeholder(tf.bool)

    # Model: pretrained ResNet50 trunk (ImageNet weights by default) plus a new head
    base_model = k.applications.ResNet50(include_top=False, pooling="avg", input_shape=(224, 224, 3))
    out = base_model(x)
    outputs = k.layers.Dense(3)(out)
    print(outputs.shape)

    # Loss, accuracy and training op
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=outputs))
    # loss = tf.reduce_mean(k.losses.sparse_categorical_crossentropy(y, outputs))
    corrects = tf.equal(tf.argmax(outputs, 1), y)
    accuracy = tf.reduce_mean(tf.cast(corrects, tf.float32))
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = tf.train.AdamOptimizer(LR).minimize(loss, global_step)
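    # Caveat (see reference 2): in some TF 1.x versions, BatchNormalization
    # layers built through Keras put their moving-average updates on the model
    # object (base_model.updates) rather than in tf.GraphKeys.UPDATE_OPS,
    # leaving update_ops above empty and freezing the BN statistics.
    # A hedged fix in that case:
    #     update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) + base_model.updates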

    # Model saver
    saver = tf.train.Saver(max_to_keep=MAX_TO_KEEP)
    best_loss = float("inf")  # start at infinity so the first evaluation is always saved

    # Variable initialization op
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

    # Session
    sess = tf.Session(config=gpu_config)

    # Initialize variables
    sess.run(init_op)
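    # Caveat (see reference 1): the Keras application loads its ImageNet weights
    # into the session that Keras itself manages, so in some setups the
    # initializer above leaves *this* session with randomly initialized weights
    # instead of the pretrained ones. A common workaround is to hand the
    # session to Keras before the model is built:
    #     sess = tf.Session(config=gpu_config)
    #     k.backend.set_session(sess)
    #     # ...then build the model and initialize only the new variables.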

    # Summary writers
    train_writer = tf.summary.FileWriter(LOG_DIR + "train/", sess.graph)
    test_writer = tf.summary.FileWriter(LOG_DIR + "test/")

    while global_epoch.eval(sess) < EPOCHS:
        train_losses = []
        test_losses = []
        train_accs = []
        test_accs = []

        # Switch to training phase; mainly affects BatchNormalization and Dropout
        k.backend.set_learning_phase(1)
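        # Note: in many TF 1.x versions set_learning_phase() only affects layers
        # built *after* it is called; the model above was built first, so feeding
        # k.backend.learning_phase() in feed_dict (see the commented-out lines
        # below) may be the more reliable way to toggle the phase.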

        for step_idx in range(STEPS_PER_ITER):
            # print("Train learning phase", k.backend.learning_phase())

            x_train, y_train = sess.run(train_next)
            _, train_batch_loss, train_batch_acc = sess.run([train_op, loss, accuracy],
                                                            feed_dict={x: x_train, y: y_train,
                                                                       # training: True,
                                                                       # k.backend.learning_phase(): 1
                                                                       })
            train_losses.append(train_batch_loss)
            train_accs.append(train_batch_acc)
        train_loss = np.mean(train_losses)
        train_acc = np.mean(train_accs)
        writer_log(train_writer, "train", train_loss, train_acc, global_epoch.eval(sess))
        print("Train epoch: {}, loss: {:.4f}, acc: {:.4f}".format(global_epoch.eval(sess), train_loss, train_acc))

        if global_epoch.eval(sess) % TEST_BY_EPOCHS == 0:
            # Initialize the test-set iterator
            sess.run(testiter.initializer)

            # Switch to inference phase
            k.backend.set_learning_phase(0)

            while True:
                # print("Test learning phase", k.backend.learning_phase())

                try:
                    x_test, y_test = sess.run(test_next)
                except tf.errors.OutOfRangeError:
                    break
                test_batch_loss, test_batch_acc = sess.run([loss, accuracy],
                                                           feed_dict={x: x_test, y: y_test,
                                                                      # training: False,
                                                                      # k.backend.learning_phase(): 0
                                                                      })
                test_losses.append(test_batch_loss)
                test_accs.append(test_batch_acc)
            test_loss = np.mean(test_losses)
            test_acc = np.mean(test_accs)
            writer_log(test_writer, "test", test_loss, test_acc, global_epoch.eval(sess))
            print("Test epoch: {}, loss: {:.4f}, acc: {:.4f}".format(global_epoch.eval(sess), test_loss, test_acc))

            # Save the model whenever the test loss improves
            if test_loss < best_loss:
                print("* Saving model ...")
                # Saver expects a file prefix, not a bare directory.
                saver.save(sess, os.path.join(CHECKPOINT_DIR, "model"), global_step=global_step)
                best_loss = test_loss
                print("* Model saved.")

        # Advance the epoch counter
        sess.run(global_epoch_increment)


if __name__ == '__main__':
    main()
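
For completeness, a hedged sketch of how the checkpoints written above could be restored later for evaluation. It assumes the same graph-building code (and the same saver) has been run first:

# Restore the most recent checkpoint saved by the training loop.
ckpt = tf.train.latest_checkpoint(CHECKPOINT_DIR)
if ckpt is not None:
    saver.restore(sess, ckpt)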