Tensorflow之基于slim训练自己的模型
假如我们需要从头开始训练一个图像识别的模型,我们可以使用tensorflow构建自己的图片分类模型,并将图片转换成tfrecord格式的文件。
tfrecord是tensorflow官方提供的一种文件类型。
这里补充下,关于tensorflow读取数据,官网给出了三种方法:
1、供给数据:在tensorflow程序运行的每一步,让python代码来供给数据
2、从文件读取数据:建立输入管线从文件中读取数据
3、预加载数据:如果数据量不太大,可以在程序中定义常量或者变量来保存所有的数据。
这里主要介绍一种比较通用、高效的数据读取方法,就是tensorflow官方推荐的标准格式:tfrecord。
tfrecord数据文件
tfrecord数据文件是一种将图像数据和标签统一存储的二进制文件,能更好的利用内存,在tensorflow中快速的复制,移动,读取,存储等。
tfrecord文件包含了tf.train.Example 协议缓冲区(protocol buffer,协议缓冲区包含了特征 Features)。
你可以写一段代码获取你的数据, 将数据填入到Example协议缓冲区(protocol buffer),将协议缓冲区序列化为一个字符串,并且通过tf.python_io.TFRecordWriter class写入到TFRecords文件。tensorflow/g3doc/how_tos/reading_data/convert_to_records.py就是这样的一个例子。
tf.train.Example的定义如下:
message Example {
Features features = 1;
};
message Features{
map<string,Feature> featrue = 1;
};
message Feature{
oneof kind{
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};
从上述代码可以看出,tf.train.Example中包含了属性名称到取值的字典,其中属性名称为字符串,属性的取值可以为字符串(BytesList)、实数列表(FloatList)或者整数列表(Int64List)。
将数据保存为tfrecord格式
具体来说,首先需要给定tfrecord文件名称,并创建一个文件:
tfrecords_filename = './tfrecords/train.tfrecords'
writer = tf.python_io.TFRecordWriter(tfrecords_filename) # 创建.tfrecord文件,准备写入
之后就可以创建一个循环来依次写入数据:
for i in range(100):
img_raw = np.random.random_integers(0,255,size=(7,30)) # 创建7*30,取值在0-255之间随机数组
img_raw = img_raw.tostring()
example = tf.train.Example(features=tf.train.Features(
feature={
'label': tf.train.Feature(int64_list = tf.train.Int64List(value=[i])),
'img_raw':tf.train.Feature(bytes_list = tf.train.BytesList(value=[img_raw]))
}))
writer.write(example.SerializeToString())
writer.close()
example = tf.train.Example()这句将数据赋给了变量example(可以看到里面是通过字典结构实现的赋值),然后用writer.write(example.SerializeToString()) 这句实现写入。
值得注意的是赋值给example的数据格式。从前面tf.train.Example的定义可知,tfrecord支持整型、浮点数和二进制三种格式,分别是
tf.train.Feature(int64_list = tf.train.Int64List(value=[int_scalar]))
tf.train.Feature(bytes_list = tf.train.BytesList(value=[array_string_or_byte]))
tf.train.Feature(bytes_list = tf.train.FloatList(value=[float_scalar]))
例如图片等数组形式(array)的数据,可以保存为numpy array的格式,转换为string,然后保存到二进制格式的feature中。对于单个的数值(scalar),可以直接赋值。这里value=[×]的[]非常重要,也就是说输入的必须是列表(list)。当然,对于输入数据是向量形式的,可以根据数据类型(float还是int)分别保存。并且在保存的时候还可以指定数据的维数。
slim框架训练模型
下载slim 和inception v4模型
https://github.com/tensorflow/models/tree/master/research/slim
将slim下载后拷贝到project目录下,然后进行以下准备工作。
1.将图片放置到指定的目录下:
图片需要按照文件夹进行分类,文件夹名就是分类的名称,具体可以参考下图:
2.运行代码,转换格式
#导入相应的模块
import tensorflow as tf
import os
import random
import math
import sys
#划分验证集训练集
_NUM_TEST = 500
#random seed
_RANDOM_SEED = 0
#数据块
_NUM_SHARDS = 2
#数据集路径
DATASET_DIR = 'E:/SVN/Gavin/Learn/Python/pygame/slim/images/'
#标签文件
LABELS_FILENAME = 'E:/SVN/Gavin/Learn/Python/pygame/slim/images/labels.txt'
#定义tfrecord 的路径和名称
def _get_dataset_filename(dataset_dir,split_name,shard_id):
output_filename = 'image_%s_%05d-of-%05d.tfrecord' % (split_name,shard_id,_NUM_SHARDS)
return os.path.join(dataset_dir,output_filename)
#判断tfrecord文件是否存在
def _dataset_exists(dataset_dir):
for split_name in ['train','test']:
for shard_id in range(_NUM_SHARDS):
#定义tfrecord的路径名字
output_filename = _get_dataset_filename(dataset_dir,split_name,shard_id)
if not tf.gfile.Exists(output_filename):
return False
return True
#获取图片以及分类
def _get_filenames_and_classes(dataset_dir):
#数据目录
directories = []
#分类名称
class_names = []
for filename in os.listdir(dataset_dir):
#合并文件路径
path = os.path.join(dataset_dir,filename)
#判断路径是否是目录
if os.path.isdir(path):
#加入数据目录
directories.append(path)
#加入类别名称
class_names.append(filename)
photo_filenames = []
#循环分类的文件夹
for directory in directories:
for filename in os.listdir(directory):
path = os.path.join(directory,filename)
#将图片加入图片列表中
photo_filenames.append(path)
#返回结果
return photo_filenames ,class_names
def int64_feature(values):
if not isinstance(values,(tuple,list)):
values = [values]
return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
def bytes_feature(values):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
#图片转换城tfexample函数
def image_to_tfexample(image_data,image_format,class_id):
return tf.train.Example(features=tf.train.Features(feature={
'image/encoded': bytes_feature(image_data),
'image/format': bytes_feature(image_format),
'image/class/label': int64_feature(class_id)
}))
def write_label_file(labels_to_class_names,dataset_dir,filename=LABELS_FILENAME):
label_filename = os.path.join(dataset_dir,filename)
with tf.gfile.Open(label_filename,'w') as f:
for label in labels_to_class_names:
class_name = labels_to_class_names[label]
f.write('%d:%s\n' % (label, class_name))
#数据转换城tfrecorad格式
def _convert_dataset(split_name,filenames,class_names_to_ids,dataset_dir):
assert split_name in ['train','test']
#计算每个数据块的大小
num_per_shard = int(len(filenames) / _NUM_SHARDS)
with tf.Graph().as_default():
with tf.Session() as sess:
for shard_id in range(_NUM_SHARDS):
#定义tfrecord的路径名字
output_filename = _get_dataset_filename(dataset_dir,split_name,shard_id)
with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
#每个数据块开始的位置
start_ndx = shard_id * num_per_shard
#每个数据块结束的位置
end_ndx = min((shard_id+1) * num_per_shard,len(filenames))
for i in range(start_ndx,end_ndx):
try:
sys.stdout.write('\r>> Converting image %d/%d shard %d '% (i+1,len(filenames),shard_id))
sys.stdout.flush()
#读取图片
image_data = tf.gfile.FastGFile(filenames[i],'rb').read()
#获取图片的类别名称
#basename获取图片路径最后一个字符串
#dirname是除了basename之外的前面的字符串路径
class_name = os.path.basename(os.path.dirname(filenames[i]))
#获取图片的id
class_id = class_names_to_ids[class_name]
#生成tfrecord文件
example = image_to_tfexample(image_data,b'jpg',class_id)
#写入数据
tfrecord_writer.write(example.SerializeToString())
except IOError as e:
print ('could not read:',filenames[1])
print ('error:' , e)
print ('skip it \n')
sys.stdout.write('\n')
sys.stdout.flush()
if __name__ == '__main__':
#判断tfrecord文件是否存在
if _dataset_exists(DATASET_DIR):
print ('tfrecord exists')
else:
#获取图片以及分类
photo_filenames,class_names = _get_filenames_and_classes(DATASET_DIR)
#将分类的list转换成dictionary{‘animal':0,'flowers:1'}
class_names_to_ids = dict(zip(class_names,range(len(class_names))))
#切分数据为测试训练集
random.seed(_RANDOM_SEED)
random.shuffle(photo_filenames)
training_filenames = photo_filenames[_NUM_TEST:]
testing_filenames = photo_filenames[:_NUM_TEST]
#数据转换
_convert_dataset('train',training_filenames,class_names_to_ids,DATASET_DIR)
_convert_dataset('test',testing_filenames,class_names_to_ids,DATASET_DIR)
#输出lables文件
#与前面的 class_names_to_ids中的元素位置相反{0:'animal',1:'flowers'}
labels_to_class_names = dict(zip(range(len(class_names)),class_names))
write_label_file(labels_to_class_names,DATASET_DIR)
完成后,生成了以下文件,包括训练集数据块和测试集数据块,另外有个label标签文件
验证模型
接下来就是本文的重点了,上一步完成后我们已经得到tfrecord格式的文件了。下一步我们就要用这些文件进行分类验证测试。
1.在slim/datasets文件夹下,找到训练的数据集,由于写法比较 一致,我们只需要拷贝其中一个,比如训练flowers用的数据源,新建一个我们自己训练使用的数据集,命名为myimages(可*命名),将代码拷贝到新文件中,最后修改。
代码如下,需要修改的几个地方我作了中文标记。代码实现的其实就是读取tfrecord文件。
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Provides data for the flowers dataset.
The dataset scripts used to create the dataset can be found at:
tensorflow/models/research/slim/datasets/download_and_convert_flowers.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf
from datasets import dataset_utils
slim = tf.contrib.slim
_FILE_PATTERN = 'image_%s_*.tfrecord' # 这里修改pattern,格式和生成tfrecord文件下的格式一致
SPLITS_TO_SIZES = {'train': 1026, 'validation': 50} #修改训练集和验证集图片数量
_NUM_CLASSES = 2 # 这里修改数据块个数
_ITEMS_TO_DESCRIPTIONS = {
'image': 'A color image of varying size.',
'label': 'A single integer between 0 and 4',
}
def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
"""Gets a dataset tuple with instructions for reading flowers.
Args:
split_name: A train/validation split name.
dataset_dir: The base directory of the dataset sources.
file_pattern: The file pattern to use when matching the dataset sources.
It is assumed that the pattern contains a '%s' string so that the split
name can be inserted.
reader: The TensorFlow reader type.
Returns:
A `Dataset` namedtuple.
Raises:
ValueError: if `split_name` is not a valid train/validation split.
"""
if split_name not in SPLITS_TO_SIZES:
raise ValueError('split name %s was not recognized.' % split_name)
if not file_pattern:
file_pattern = _FILE_PATTERN
file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
# Allowing None in the signature so that dataset_factory can use the default.
if reader is None:
reader = tf.TFRecordReader
keys_to_features = {
'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
'image/class/label': tf.FixedLenFeature(
[], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
}
items_to_handlers = {
'image': slim.tfexample_decoder.Image(),
'label': slim.tfexample_decoder.Tensor('image/class/label'),
}
decoder = slim.tfexample_decoder.TFExampleDecoder(
keys_to_features, items_to_handlers)
labels_to_names = None
if dataset_utils.has_labels(dataset_dir):
labels_to_names = dataset_utils.read_label_file(dataset_dir)
return slim.dataset.Dataset(
data_sources=file_pattern,
reader=reader,
decoder=decoder,
num_samples=SPLITS_TO_SIZES[split_name],
items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
num_classes=_NUM_CLASSES,
labels_to_names=labels_to_names)
注意,我们还需要修改 dataset_factory.py文件下的datasets_map字典,如下
datasets_map = {
'cifar10': cifar10,
'flowers': flowers,
'imagenet': imagenet,
'mnist': mnist,
'myimages': myimages, # 这里新增个人训练的数据集
}
其中,'cifar10','flowers','imagenet','mnist'是官方提供的数据集,而‘myimages’就是我们刚刚新建的文件,新增到datasets_map。3.在slim文件夹下新建一个model文件夹,用于保存训练生成的模型
4.如果标签是中文,修改slim/datasets/dataset_utils.py
sys.setdefaultencoding("utf-8") #中文标签,增加utf-8
5.在slim目录下编写执行训练数据的脚本。
python E:/SVN/Gavin/Learn/Python/pygame/slim/train_image_classifier.py ^
--dataset_name=myimages ^
--dataset_split_name=train ^
--train_dir= E:\SVN\Gavin\Learn\Python\pygame\slim\models ^
--dataset_dir=E:\SVN\Gavin\Learn\Python\pygame\slim\images ^
--batch_size=10 ^
--max_number_of_steps=10000 ^
--model_name=inception_v3 ^
pause
简单解释下:
train_dir 训练生成的模型存放位置
dataset_split_name=train 代表使用的是训练集,之前拆分为训练集和测试集
dataset_dir 训练图片存放位置
6.执行预测脚本,使用eval_image_classifier.py文件
如果使用的CPU,那么训练时间是比较漫长的。
推荐阅读
-
Tensorflow之基于slim训练自己的模型
-
tensorflow深度学习实战笔记(一):使用tensorflow slim自带的模型训练自己的数据
-
TensorFlow卷积神经网络之使用训练好的模型识别猫狗图片
-
使用TensorFlow提供的slim模型来训练数据模型供iOS使用
-
利用TensorFlow Object Detection API的预训练模型训练自己的数据
-
使用tensorflow搭建图像识别模型SSD之训练自己模型
-
机器学习Tensorflow基于MNIST数据集识别自己的手写数字(读取和测试自己的模型)
-
TensorFlow卷积神经网络之使用训练好的模型识别猫狗图片
-
利用TensorFlow Object Detection API的预训练模型训练自己的数据