欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Tensorflow之基于slim训练自己的模型

程序员文章站 2024-03-15 09:17:29
...

    假如我们需要从头开始训练一个图像识别的模型,我们可以使用tensorflow构建自己的图片分类模型,并将图片转换成tfrecord格式的文件。

tfrecord是tensorflow官方提供的一种文件类型。

这里补充下,关于tensorflow读取数据,官网给出了三种方法:
1、供给数据:在tensorflow程序运行的每一步,让python代码来供给数据
2、从文件读取数据:建立输入管线从文件中读取数据
3、预加载数据:如果数据量不太大,可以在程序中定义常量或者变量来保存所有的数据。

这里主要介绍一种比较通用、高效的数据读取方法,就是tensorflow官方推荐的标准格式:tfrecord。

tfrecord数据文件


tfrecord数据文件是一种将图像数据和标签统一存储的二进制文件,能更好的利用内存,在tensorflow中快速的复制,移动,读取,存储等。

tfrecord文件包含了tf.train.Example 协议缓冲区(protocol buffer,协议缓冲区包含了特征 Features)。

你可以写一段代码获取你的数据, 将数据填入到Example协议缓冲区(protocol buffer),将协议缓冲区序列化为一个字符串,并且通过tf.python_io.TFRecordWriter class写入到TFRecords文件。tensorflow/g3doc/how_tos/reading_data/convert_to_records.py就是这样的一个例子。

tf.train.Example的定义如下:

message Example {
 Features features = 1;
};

message Features{
 map<string,Feature> featrue = 1;
};

message Feature{
    oneof kind{
        BytesList bytes_list = 1;
        FloatList float_list = 2;
        Int64List int64_list = 3;
    }
};

从上述代码可以看出,tf.train.Example中包含了属性名称到取值的字典,其中属性名称为字符串,属性的取值可以为字符串(BytesList)、实数列表(FloatList)或者整数列表(Int64List)。

将数据保存为tfrecord格式

具体来说,首先需要给定tfrecord文件名称,并创建一个文件:

tfrecords_filename = './tfrecords/train.tfrecords'
writer = tf.python_io.TFRecordWriter(tfrecords_filename) # 创建.tfrecord文件,准备写入

之后就可以创建一个循环来依次写入数据:

    for i in range(100):
        img_raw = np.random.random_integers(0,255,size=(7,30)) # 创建7*30,取值在0-255之间随机数组
        img_raw = img_raw.tostring()
        example = tf.train.Example(features=tf.train.Features(
                feature={
                'label': tf.train.Feature(int64_list = tf.train.Int64List(value=[i])),     
                'img_raw':tf.train.Feature(bytes_list = tf.train.BytesList(value=[img_raw]))
                }))
        writer.write(example.SerializeToString()) 

    writer.close()

example = tf.train.Example()这句将数据赋给了变量example(可以看到里面是通过字典结构实现的赋值),然后用writer.write(example.SerializeToString()) 这句实现写入。

值得注意的是赋值给example的数据格式。从前面tf.train.Example的定义可知,tfrecord支持整型、浮点数和二进制三种格式,分别是

tf.train.Feature(int64_list = tf.train.Int64List(value=[int_scalar]))
tf.train.Feature(bytes_list = tf.train.BytesList(value=[array_string_or_byte]))
tf.train.Feature(bytes_list = tf.train.FloatList(value=[float_scalar]))

例如图片等数组形式(array)的数据,可以保存为numpy array的格式,转换为string,然后保存到二进制格式的feature中。对于单个的数值(scalar),可以直接赋值。这里value=[×]的[]非常重要,也就是说输入的必须是列表(list)。当然,对于输入数据是向量形式的,可以根据数据类型(float还是int)分别保存。并且在保存的时候还可以指定数据的维数。

slim框架训练模型

下载slim 和inception v4模型

https://github.com/tensorflow/models/tree/master/research/slim

将slim下载后拷贝到project目录下,然后进行以下准备工作。

1.将图片放置到指定的目录下:

图片需要按照文件夹进行分类,文件夹名就是分类的名称,具体可以参考下图:

Tensorflow之基于slim训练自己的模型

这里我将分类数据集放到images目录下,images是在slim目录下新建的文件夹。


2.运行代码,转换格式 

#导入相应的模块
import tensorflow as tf
import os
import random
import math
import sys
#划分验证集训练集
_NUM_TEST = 500
#random seed
_RANDOM_SEED = 0
#数据块
_NUM_SHARDS = 2
#数据集路径
DATASET_DIR = 'E:/SVN/Gavin/Learn/Python/pygame/slim/images/'
#标签文件
LABELS_FILENAME = 'E:/SVN/Gavin/Learn/Python/pygame/slim/images/labels.txt'
#定义tfrecord 的路径和名称
def _get_dataset_filename(dataset_dir,split_name,shard_id):
    output_filename = 'image_%s_%05d-of-%05d.tfrecord' % (split_name,shard_id,_NUM_SHARDS)
    return os.path.join(dataset_dir,output_filename)
#判断tfrecord文件是否存在
def _dataset_exists(dataset_dir):
    for split_name in ['train','test']:
        for shard_id in range(_NUM_SHARDS):
            #定义tfrecord的路径名字
            output_filename = _get_dataset_filename(dataset_dir,split_name,shard_id)
        if not tf.gfile.Exists(output_filename):
            return False
    return True
#获取图片以及分类
def _get_filenames_and_classes(dataset_dir):
    #数据目录
    directories = []
    #分类名称
    class_names = []
    for filename in os.listdir(dataset_dir):
        #合并文件路径
        path = os.path.join(dataset_dir,filename)
        #判断路径是否是目录
        if os.path.isdir(path):
            #加入数据目录
            directories.append(path)
            #加入类别名称
            class_names.append(filename)
    photo_filenames = []
    #循环分类的文件夹
    for directory in directories:
        for filename in os.listdir(directory):
            path = os.path.join(directory,filename)
            #将图片加入图片列表中
            photo_filenames.append(path)
    #返回结果
    return photo_filenames ,class_names
def int64_feature(values):
    if not isinstance(values,(tuple,list)):
        values = [values]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
def bytes_feature(values):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
#图片转换城tfexample函数
def image_to_tfexample(image_data,image_format,class_id):
    return tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': bytes_feature(image_data),
        'image/format': bytes_feature(image_format),
        'image/class/label': int64_feature(class_id)
    }))
def write_label_file(labels_to_class_names,dataset_dir,filename=LABELS_FILENAME):
    label_filename = os.path.join(dataset_dir,filename)
    with tf.gfile.Open(label_filename,'w') as f:
        for label in labels_to_class_names:
            class_name = labels_to_class_names[label]
            f.write('%d:%s\n' % (label, class_name))
#数据转换城tfrecorad格式
def _convert_dataset(split_name,filenames,class_names_to_ids,dataset_dir):
    assert split_name in ['train','test']
    #计算每个数据块的大小
    num_per_shard = int(len(filenames) / _NUM_SHARDS)
    with tf.Graph().as_default():
        with tf.Session() as sess:
            for shard_id in range(_NUM_SHARDS):
            #定义tfrecord的路径名字
                output_filename = _get_dataset_filename(dataset_dir,split_name,shard_id)
                with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
                    #每个数据块开始的位置
                    start_ndx = shard_id * num_per_shard
                    #每个数据块结束的位置
                    end_ndx = min((shard_id+1) * num_per_shard,len(filenames))
                    for i in range(start_ndx,end_ndx):
                        try:
                            sys.stdout.write('\r>> Converting image %d/%d shard %d '% (i+1,len(filenames),shard_id))
                            sys.stdout.flush()
                            #读取图片
                            image_data = tf.gfile.FastGFile(filenames[i],'rb').read()
                            #获取图片的类别名称
                            #basename获取图片路径最后一个字符串
                            #dirname是除了basename之外的前面的字符串路径
                            class_name = os.path.basename(os.path.dirname(filenames[i]))
                            #获取图片的id
                            class_id = class_names_to_ids[class_name]
                            #生成tfrecord文件
                            example = image_to_tfexample(image_data,b'jpg',class_id)
                            #写入数据
                            tfrecord_writer.write(example.SerializeToString())
                        except IOError  as e:
                            print ('could not read:',filenames[1])
                            print ('error:' , e)
                            print ('skip it \n')
    sys.stdout.write('\n')
    sys.stdout.flush()

if __name__ == '__main__':
    #判断tfrecord文件是否存在
    if _dataset_exists(DATASET_DIR):
        print ('tfrecord exists')
    else:
        #获取图片以及分类
        photo_filenames,class_names = _get_filenames_and_classes(DATASET_DIR)
        #将分类的list转换成dictionary{‘animal':0,'flowers:1'}
        class_names_to_ids = dict(zip(class_names,range(len(class_names))))
        #切分数据为测试训练集
        random.seed(_RANDOM_SEED)
        random.shuffle(photo_filenames)
        training_filenames = photo_filenames[_NUM_TEST:]
        testing_filenames = photo_filenames[:_NUM_TEST]
        #数据转换
        _convert_dataset('train',training_filenames,class_names_to_ids,DATASET_DIR)
        _convert_dataset('test',testing_filenames,class_names_to_ids,DATASET_DIR)
        #输出lables文件
        #与前面的 class_names_to_ids中的元素位置相反{0:'animal',1:'flowers'}
        labels_to_class_names = dict(zip(range(len(class_names)),class_names))
        write_label_file(labels_to_class_names,DATASET_DIR)

完成后,生成了以下文件,包括训练集数据块和测试集数据块,另外有个label标签文件

Tensorflow之基于slim训练自己的模型

验证模型

接下来就是本文的重点了,上一步完成后我们已经得到tfrecord格式的文件了。下一步我们就要用这些文件进行分类验证测试。

1.在slim/datasets文件夹下,找到训练的数据集,由于写法比较 一致,我们只需要拷贝其中一个,比如训练flowers用的数据源,新建一个我们自己训练使用的数据集,命名为myimages(可*命名),将代码拷贝到新文件中,最后修改。

Tensorflow之基于slim训练自己的模型

代码如下,需要修改的几个地方我作了中文标记。代码实现的其实就是读取tfrecord文件。

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides data for the flowers dataset.

The dataset scripts used to create the dataset can be found at:
tensorflow/models/research/slim/datasets/download_and_convert_flowers.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import tensorflow as tf

from datasets import dataset_utils

slim = tf.contrib.slim

_FILE_PATTERN = 'image_%s_*.tfrecord' # 这里修改pattern,格式和生成tfrecord文件下的格式一致

SPLITS_TO_SIZES = {'train': 1026, 'validation': 50} #修改训练集和验证集图片数量

_NUM_CLASSES = 2 # 这里修改数据块个数

_ITEMS_TO_DESCRIPTIONS = {
    'image': 'A color image of varying size.',
    'label': 'A single integer between 0 and 4',
}


def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
  """Gets a dataset tuple with instructions for reading flowers.

  Args:
    split_name: A train/validation split name.
    dataset_dir: The base directory of the dataset sources.
    file_pattern: The file pattern to use when matching the dataset sources.
      It is assumed that the pattern contains a '%s' string so that the split
      name can be inserted.
    reader: The TensorFlow reader type.

  Returns:
    A `Dataset` namedtuple.

  Raises:
    ValueError: if `split_name` is not a valid train/validation split.
  """
  if split_name not in SPLITS_TO_SIZES:
    raise ValueError('split name %s was not recognized.' % split_name)

  if not file_pattern:
    file_pattern = _FILE_PATTERN
  file_pattern = os.path.join(dataset_dir, file_pattern % split_name)

  # Allowing None in the signature so that dataset_factory can use the default.
  if reader is None:
    reader = tf.TFRecordReader

  keys_to_features = {
      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
      'image/class/label': tf.FixedLenFeature(
          [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
  }

  items_to_handlers = {
      'image': slim.tfexample_decoder.Image(),
      'label': slim.tfexample_decoder.Tensor('image/class/label'),
  }

  decoder = slim.tfexample_decoder.TFExampleDecoder(
      keys_to_features, items_to_handlers)

  labels_to_names = None
  if dataset_utils.has_labels(dataset_dir):
    labels_to_names = dataset_utils.read_label_file(dataset_dir)

  return slim.dataset.Dataset(
      data_sources=file_pattern,
      reader=reader,
      decoder=decoder,
      num_samples=SPLITS_TO_SIZES[split_name],
      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
      num_classes=_NUM_CLASSES,
      labels_to_names=labels_to_names)



注意,我们还需要修改 dataset_factory.py文件下的datasets_map字典,如下

datasets_map = {
    'cifar10': cifar10,
    'flowers': flowers,
    'imagenet': imagenet,
    'mnist': mnist,
    'myimages': myimages,  # 这里新增个人训练的数据集
}
其中,'cifar10','flowers','imagenet','mnist'是官方提供的数据集,而‘myimages’就是我们刚刚新建的文件,新增到datasets_map。


3.在slim文件夹下新建一个model文件夹,用于保存训练生成的模型

4.如果标签是中文,修改slim/datasets/dataset_utils.py

sys.setdefaultencoding("utf-8")  #中文标签,增加utf-8  

5.在slim目录下编写执行训练数据的脚本。

python E:/SVN/Gavin/Learn/Python/pygame/slim/train_image_classifier.py ^
--dataset_name=myimages ^
--dataset_split_name=train ^
--train_dir= E:\SVN\Gavin\Learn\Python\pygame\slim\models ^
--dataset_dir=E:\SVN\Gavin\Learn\Python\pygame\slim\images ^
--batch_size=10 ^
--max_number_of_steps=10000 ^
--model_name=inception_v3 ^
pause

简单解释下:

train_dir 训练生成的模型存放位置

dataset_split_name=train 代表使用的是训练集,之前拆分为训练集和测试集
dataset_dir  训练图片存放位置

6.执行预测脚本,使用eval_image_classifier.py文件

如果使用的CPU,那么训练时间是比较漫长的。





















相关标签: tensorflow slim