欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

迁移学习CNN图像分类模型 - 花朵图片分类

程序员文章站 2024-03-20 23:24:10
...

训练一个好的卷积神经网络模型进行图像分类不仅需要计算资源还需要很长的时间。特别是模型比较复杂和数据量比较大的时候。普通的电脑动不动就需要训练几天的时间。为了能够快速地训练好自己的花朵图片分类器,我们可以使用别人已经训练好的模型参数,在此基础之上训练我们的模型。这个便属于迁移学习。本文提供训练数据集和代码下载。
迁移学习CNN图像分类模型 - 花朵图片分类
迁移学习CNN图像分类模型 - 花朵图片分类

原理:卷积神经网络模型总体上可以分为两部分,前面的卷积层和后面的全连接层。卷积层的作用是图片特征的提取,全连接层作用是特征的分类。我们的思路便是在inception-v3网络模型上,修改全连接层,保留卷积层。卷积层的参数使用的是别人已经训练好的,全连接层的参数需要我们初始化并使用我们自己的数据来训练和学习。

迁移学习CNN图像分类模型 - 花朵图片分类

上面inception-v3模型图红色箭头前面部分是卷积层,后面是全连接层。我们需要修改修改全连接层,同时把模型的最终输出改为5。

由于这里使用了tensorflow框架,所以,我们需要获取上图红色箭头所在位置的张量BOTTLENECK_TENSOR_NAME(最后一个卷积层**函数的输出值,个数为2048)以及模型最开始的输入数据的张量JPEG_DATA_TENSOR_NAME。获取这两个张量的作用是,图片训练数据通过JPEG_DATA_TENSOR_NAME张量输入模型,通过BOTTLENECK_TENSOR_NAME张量获取通过卷积层之后的图片特征。

BOTTLENECK_TENSOR_SIZE = 2048
BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'

通过下面的链接下载inception-v3模型,其中包含已经训练好的参数。

模型下载链接:地址

训练数据花朵图片下载:地址

通过下面的代码加载模型,同时获取上面所述的两个张量。

   # 读取已经训练好的Inception-v3模型。
    with gfile.FastGFile(os.path.join(MODEL_DIR, MODEL_FILE), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def(
        graph_def, return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])

由于我们模型的功能是对五种花进行分类,所以,我们需要修改全连接层,这里,我们只增加一个全连接层。全连接层的输入数据便是BOTTLENECK_TENSOR_NAME张量。

    # 定义一层全链接层
    with tf.name_scope('final_training_ops'):
        weights = tf.Variable(tf.truncated_normal([BOTTLENECK_TENSOR_SIZE, n_classes], stddev=0.001))
        biases = tf.Variable(tf.zeros([n_classes]))
        logits = tf.matmul(bottleneck_input, weights) + biases
        final_tensor = tf.nn.softmax(logits)

最后便是定义交叉熵损失函数。模型使用反向传播训练,而训练的参数并不是模型的所有参数,仅仅是全连接层的参数,卷积层的参数是不变的。

    # 定义交叉熵损失函数。
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=ground_truth_input)
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    train_step = tf.train.GradientDescentOptimizer(LEARNING_RATE).minimize(cross_entropy_mean)

那么接下来的是如何给我们的模型输入数据了,这里提供了几个操作数据的函数。由于训练数据集比较小,先把所有的图片通过JPEG_DATA_TENSOR_NAME张量输入模型,然后获取BOTTLENECK_TENSOR_NAME张量的值并保存到硬盘中。在模型训练的时候,从硬盘中读取所保存的BOTTLENECK_TENSOR_NAME张量的值作为全连接层的输入数据。因为一张图片可能会被使用多次。

# 输入图片并获取`BOTTLENECK_TENSOR_NAME`张量的值
def get_or_create_bottleneck(sess, image_lists, label_name, index, category, jpeg_data_tensor, bottleneck_tensor)

# 从硬盘中读取`BOTTLENECK_TENSOR_NAME`张量的值,用于训练
def get_or_create_bottleneck(sess, image_lists, label_name, index, category, jpeg_data_tensor, bottleneck_tensor):

# 从硬盘中读取`BOTTLENECK_TENSOR_NAME`张量的值,用于测试。
def get_test_bottlenecks(sess, image_lists, n_classes, jpeg_data_tensor, bottleneck_tensor)

不到5分钟就可以训练好我们的模型,精确度还蛮高的。下图是本人运行的结果。

迁移学习CNN图像分类模型 - 花朵图片分类

源码地址:https://github.com/liangyihuai/my_tensorflow/tree/master/com/huai/converlution/transfer_learning