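How a Transfer Learning Network Is Implemented
Here is a belated set of study notes on transfer learning networks.
Training a deep neural network by hand takes a great deal of time. Instead, we can take a well-known network model whose parameters have already been trained and use it to extract image features, which avoids spending that time training the parameters ourselves.
The main idea is very simple. With Inception-v3 as the feature extraction network, we pass each training image through Inception-v3 to obtain a feature vector, then use a single fully connected layer to connect that feature vector to the label. The only thing left to train is this one fully connected layer.
To classify a new image with this network, we only need to obtain its feature vector and pass it through the fully connected layer.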
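The prediction step described above can be sketched in plain Python. This is only an illustration under stated assumptions: the function names `softmax` and `classify` and the toy numbers are made up for this example, and the real feature vector from Inception-v3 has 2048 entries rather than 3.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def classify(features, weights, biases):
    """Apply one fully connected layer plus softmax to a feature vector.

    weights is a list of rows, one row per feature, each with n_classes entries;
    biases has n_classes entries.
    Returns (index of the most likely class, class probabilities).
    """
    n_classes = len(biases)
    logits = [
        sum(f * weights[i][c] for i, f in enumerate(features)) + biases[c]
        for c in range(n_classes)
    ]
    probs = softmax(logits)
    best = max(range(n_classes), key=lambda c: probs[c])
    return best, probs

# Toy example: a 3-dimensional "feature vector" and 2 classes (made-up numbers).
features = [1.0, 0.0, 2.0]
weights = [[0.5, -0.5], [0.1, 0.2], [1.0, 0.0]]
biases = [0.0, 0.1]
label, probs = classify(features, weights, biases)
print(label, probs)
```

In the real network the feature vector comes from Inception-v3 and the weights and biases come from training, but the arithmetic of the final layer is the same.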
The pretrained Inception-v3 model used here is the one provided by Google: [https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip](https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip)
The dataset used in this example: [http://download.tensorflow.org/example_images/flower_photos.tgz](http://download.tensorflow.org/example_images/flower_photos.tgz)
After extraction, the dataset contains 5 subfolders. Each subfolder is named after a flower and represents one class. Each flower has 734 images on average; the images are in RGB color mode and vary in size.
Main code:

    import os.path

    import tensorflow as tf  # TensorFlow 1.x API
    from tensorflow.python.platform import gfile

    # MODEL_DIR, MODEL_FILE, STEPS, BATCH and the other constants are defined
    # in the complete listing at the end of this post.

    def main():
        # Read the image file names from the image folder and split them.
        image_list = create_image_list(TEST_PERCENTAGE, VALIDATION_PERCENTAGE)
        # Number of classes; this should be 5 here.
        n_classes = len(image_list.keys())

        # Read the serialized model from file and restore it as a GraphDef.
        with gfile.FastGFile(os.path.join(MODEL_DIR, MODEL_FILE), 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        # Get two tensors from Inception-v3: the bottleneck (feature) output
        # and the JPEG data input.
        bottleneck_tensor, jpeg_data_tensor = tf.import_graph_def(
            graph_def,
            return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])

        # Placeholder for the feature vectors.
        bottleneck_input = tf.placeholder(
            tf.float32, [None, BOTTLENECK_TENSOR_SIZE],
            name='BottleneckInputPlaceholder')
        # Placeholder for the true labels (one-hot encoded).
        ground_truth_input = tf.placeholder(
            tf.float32, [None, n_classes], name='GroundTruthInput')

        # A single fully connected layer.
        with tf.name_scope('final_training_ops'):
            weights = tf.Variable(tf.truncated_normal(
                [BOTTLENECK_TENSOR_SIZE, n_classes], stddev=0.001))
            biases = tf.Variable(tf.zeros([n_classes]))
            logits = tf.matmul(bottleneck_input, weights) + biases
            final_tensor = tf.nn.softmax(logits)

        # Training: mean cross-entropy loss minimized by gradient descent.
        cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
            logits=logits, labels=ground_truth_input)
        cross_entropy_mean = tf.reduce_mean(cross_entropy)
        train_step = tf.train.GradientDescentOptimizer(
            LEARNING_RATE).minimize(cross_entropy_mean)

        # Accuracy: fraction of examples whose predicted class matches the label.
        with tf.name_scope('evaluation'):
            correct_prediction = tf.equal(tf.argmax(final_tensor, 1),
                                          tf.argmax(ground_truth_input, 1))
            evaluation_step = tf.reduce_mean(
                tf.cast(correct_prediction, tf.float32))

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            for i in range(STEPS):
                # Use the tensors taken from Inception-v3 to compute the
                # features and labels of a random training batch.
                train_bottlenecks, train_ground_truth = get_random_cached_bottlenecks(
                    sess, n_classes, image_list, BATCH, 'training',
                    jpeg_data_tensor, bottleneck_tensor)
                sess.run(train_step,
                         feed_dict={bottleneck_input: train_bottlenecks,
                                    ground_truth_input: train_ground_truth})

                if i % 100 == 0 or i + 1 == STEPS:
                    # Same for a random validation batch.
                    validation_bottles, validation_ground_truth = get_random_cached_bottlenecks(
                        sess, n_classes, image_list, BATCH, 'validation',
                        jpeg_data_tensor, bottleneck_tensor)
                    validation_accuracy = sess.run(evaluation_step, feed_dict={
                        bottleneck_input: validation_bottles,
                        ground_truth_input: validation_ground_truth})
                    # Print the accuracy once it has been computed.
                    print('Step %d: Validation accuracy on random sampled %d examples = %.1f%%' %
                          (i, BATCH, validation_accuracy * 100))

            # Compute the features and labels of the test images.
            test_bottlenecks, test_ground_truth = get_test_bottlenecks(
                sess, image_list, n_classes, jpeg_data_tensor, bottleneck_tensor)
            test_accuracy = sess.run(evaluation_step,
                                     feed_dict={bottleneck_input: test_bottlenecks,
                                                ground_truth_input: test_ground_truth})
            print('Final test accuracy = %.1f%%' % (test_accuracy * 100))

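At this point main() is functionally complete: we split the sample data, took the tensors that compute the features out of Inception-v3, created a fully connected layer, fed the features into it as input, trained it, and measured its accuracy.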
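To make the two quantities that main() tracks more concrete, here is a small plain-Python sketch of the per-example softmax cross-entropy and of the accuracy measure. It is an illustration, not the TensorFlow implementation: the helper names and the toy batch are assumptions made for this example.

```python
import math

def softmax_cross_entropy(logits, one_hot):
    """Cross-entropy between softmax(logits) and a one-hot label, for one example."""
    peak = max(logits)
    # log(sum(exp(logits))), computed in a numerically stable way
    log_sum = peak + math.log(sum(math.exp(v - peak) for v in logits))
    # -sum(y_c * log(softmax_c)) equals sum(y_c * (log_sum - logit_c))
    return sum(y * (log_sum - v) for v, y in zip(logits, one_hot))

def argmax(values):
    """Index of the largest entry."""
    return max(range(len(values)), key=lambda i: values[i])

def accuracy(batch_logits, batch_one_hot):
    """Fraction of examples whose highest logit matches the label.

    Softmax preserves the ordering of the logits, so comparing logits gives
    the same result as comparing softmax outputs.
    """
    hits = sum(argmax(l) == argmax(y)
               for l, y in zip(batch_logits, batch_one_hot))
    return hits / len(batch_logits)

# Toy batch: 3 examples, 2 classes (made-up numbers).
batch_logits = [[2.0, 0.0], [0.0, 1.0], [3.0, 1.0]]
batch_labels = [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
mean_loss = sum(softmax_cross_entropy(l, y)
                for l, y in zip(batch_logits, batch_labels)) / len(batch_logits)
acc = accuracy(batch_logits, batch_labels)
print(mean_loss, acc)
```

The third example is misclassified, so the accuracy of this toy batch is two out of three, and that example also contributes the largest share of the loss.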
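All that remains is to fill in the functions it needs:
- Read the images from the image folder: create_image_list
- Use the tensors taken from Inception-v3 to compute the labels and features of a batch of images: get_random_cached_bottlenecks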
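Reading the images from the image folder: create_image_list
The arguments are the percentages of images that go to the test set and the validation set. Each image is randomly marked as training, testing, or validation in those proportions: draw a random integer from 0 to 99; if it falls in [0, 10) the image goes to the validation set, if it falls in [10, 20) it goes to the test set, and the remaining 80% of values put it in the training set.
The final result is a dictionary keyed by the lowercase class name; each entry stores the directory name under 'dir' and the file names of its images under 'training', 'testing', and 'validation'.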
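The split rule and the shape of the resulting dictionary can be illustrated with a short standalone sketch. The helper names `assign_split` and `split_files`, the file names, and the fixed seed are all assumptions made for this illustration; the post's own function uses np.random.randint instead of the random module.

```python
import random

def assign_split(chance, validation_percentage=10, test_percentage=10):
    """Map an integer in [0, 100) to 'validation', 'testing' or 'training'."""
    if chance < validation_percentage:
        return 'validation'
    if chance < validation_percentage + test_percentage:
        return 'testing'
    return 'training'

def split_files(file_names, rng):
    """Distribute file names over the three lists, one random draw per file."""
    buckets = {'training': [], 'testing': [], 'validation': []}
    for name in file_names:
        buckets[assign_split(rng.randrange(100))].append(name)
    return buckets

# Hypothetical file names, only for illustration.
rng = random.Random(0)
files = ['img_%03d.jpg' % i for i in range(1000)]
buckets = split_files(files, rng)

# One entry of the result dictionary, in the same layout as described above.
result = {'daisy': dict(dir='daisy', **buckets)}
print({key: len(value) for key, value in buckets.items()})
```

Because each file is assigned by an independent random draw, the three lists only approximate an 80/10/10 split; the exact sizes change from run to run unless the seed is fixed.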
The code is as follows:

    import glob
    import os.path

    import numpy as np

    VALIDATION_PERCENTAGE = 10
    TEST_PERCENTAGE = 10
    INPUT_DATA = 'D://python//flower_photos//flower_photos'

    def create_image_list(test_percentage, validation_percentage):
        result = {}
        # os.walk yields the root directory first, then each subdirectory.
        sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]
        is_root_dir = True
        for sub_dir in sub_dirs:
            if is_root_dir:
                # Skip the root directory itself.
                is_root_dir = False
                continue

            # Collect every JPEG file in this class directory.
            extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
            file_list = []
            dir_name = os.path.basename(sub_dir)
            for extension in extensions:
                file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)
                file_list.extend(glob.glob(file_glob))
            if not file_list:
                continue

            # The lowercase directory name is used as the class label.
            label_name = dir_name.lower()
            training_images = []
            testing_images = []
            validation_images = []
            for file_name in file_list:
                base_name = os.path.basename(file_name)
                # Random integer in [0, 100) decides which set the image joins.
                chance = np.random.randint(100)
                if chance < validation_percentage:
                    validation_images.append(base_name)
                elif chance < (test_percentage + validation_percentage):
                    testing_images.append(base_name)
                else:
                    training_images.append(base_name)

            result[label_name] = {
                'dir': dir_name,
                'training': training_images,
                'testing': testing_images,
                'validation': validation_images
            }
        return result

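Computing image labels and features with the tensors: get_random_cached_bottlenecks
We can nest function calls and pass different arguments, such as "training" or "validation", to fetch sample images from different sets and process them into a feature matrix. The main steps of the function are:
- Randomly pick a class
- Randomly pick an image from that class's list for the requested category
- Compute the features of that image
- Repeat the first three steps how_many times

The code is as follows:

    def get_random_cached_bottlenecks(sess, n_classes, image_lists, how_many, category,
                                      jpeg_data_tensor, bottleneck_tensor):
        '''
        :param sess: TensorFlow session used to run the feature extraction
        :param n_classes: number of classes; this should be 5 here
        :param image_lists: class and file information for the images, as returned by create_image_list
        :param how_many: batch size, i.e. the number of samples to draw per call
        :param category: which set to draw from: 'training', 'testing' or 'validation'
        :param jpeg_data_tensor: input tensor for the image data
        :param bottleneck_tensor: tensor that produces the features
        :return: feature matrix and label matrix for how_many images of the given category
        '''
        bottlenecks = []
        ground_truths = []
        for _ in range(how_many):
            # Pick a class.
            label_index = random.randrange(n_classes)
            label_name = list(image_lists.keys())[label_index]
            # Pick an image.
            image_index = random.randrange(65536)
            # Compute its features.
            bottleneck = get_or_create_bottleneck(sess, image_lists, label_name,
                                                  image_index, category,
                                                  jpeg_data_tensor, bottleneck_tensor)
            # One-hot encoding of the label.
            ground_truth = np.zeros(n_classes, dtype=np.float32)
            ground_truth[label_index] = 1.0
            bottlenecks.append(bottleneck)
            ground_truths.append(ground_truth)
        return bottlenecks, ground_truths
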
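Computing features from a given image and the tensors: get_or_create_bottleneck
The idea here is:
To save running time, the feature extraction result of every image is stored in a local txt file. When the program runs again, it checks whether a feature file exists for the image. If it does, the stored features are read directly; otherwise the feature vector is computed through the tensor and then saved locally so that it can be read directly next time. Of course, you can also simply compute the features and use them without saving.

    # CACHE_DIR and get_image_path come from the book's code and are not
    # defined in this post.

    def get_bottleneck_path(image_lists, label_name, index, category):
        # The cache file mirrors the image path under CACHE_DIR, plus '.txt'.
        return get_image_path(image_lists, CACHE_DIR, label_name, index, category) + '.txt'

    def run_bottleneck_on_image(sess, image_data, image_data_tensor, bottleneck_tensor):
        # Run the image through Inception-v3 up to the bottleneck tensor.
        bottleneck_values = sess.run(bottleneck_tensor, {image_data_tensor: image_data})
        # Remove the size-1 dimensions to get a flat feature vector.
        bottleneck_values = np.squeeze(bottleneck_values)
        return bottleneck_values

    def get_or_create_bottleneck(sess, image_lists, label_name, index,
                                 category, jpeg_data_tensor, bottleneck_tensor):
        label_lists = image_lists[label_name]
        sub_dir = label_lists['dir']
        # Make sure the cache directory for this class exists.
        sub_dir_path = os.path.join(CACHE_DIR, sub_dir)
        print("sub_dir_path: ", sub_dir_path)
        if not os.path.exists(sub_dir_path):
            os.makedirs(sub_dir_path)

        bottleneck_path = get_bottleneck_path(image_lists, label_name, index, category)
        print("bottleneck_path: ", bottleneck_path)
        if not os.path.exists(bottleneck_path):
            # No cached features yet: read the image from its file path.
            image_path = get_image_path(image_lists, INPUT_DATA, label_name, index, category)
            with gfile.FastGFile(image_path, 'rb') as image_file:
                image_data = image_file.read()
            # Extract the image features.
            bottleneck_values = run_bottleneck_on_image(
                sess, image_data, jpeg_data_tensor, bottleneck_tensor)
            # Save the image features as comma-separated text.
            bottleneck_string = ','.join(str(x) for x in bottleneck_values)
            with open(bottleneck_path, 'w') as bottleneck_file:
                bottleneck_file.write(bottleneck_string)
        else:
            # Cached features exist: read them back and parse them.
            with open(bottleneck_path, 'r') as bottleneck_file:
                bottleneck_string = bottleneck_file.read()
            bottleneck_values = [float(x) for x in bottleneck_string.split(',')]
        return bottleneck_values
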
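The code above calls get_image_path, which is not shown in this post. The following is only a guess at what it might look like, assuming it builds the path the same way the simplified listing later in this post does (the random index is reduced modulo the length of the category list). The body and the sample data are assumptions, not the book's code.

```python
import os.path

def get_image_path(image_lists, image_dir, label_name, index, category):
    """Build the path of one image (hypothetical reconstruction)."""
    label_lists = image_lists[label_name]
    category_list = label_lists[category]
    # The caller draws index from range(65536); the modulo maps it onto
    # a valid position in the list, whatever its length.
    mod_index = index % len(category_list)
    base_name = category_list[mod_index]
    sub_dir = label_lists['dir']
    return os.path.join(image_dir, sub_dir, base_name)

# Hypothetical image list in the layout returned by create_image_list.
image_lists = {
    'daisy': {
        'dir': 'daisy',
        'training': ['a.jpg', 'b.jpg', 'c.jpg'],
        'testing': ['d.jpg'],
        'validation': ['e.jpg'],
    }
}
path = get_image_path(image_lists, 'flower_photos', 'daisy', 4, 'training')
print(path)
```

With this reading, passing CACHE_DIR instead of INPUT_DATA as image_dir gives the matching cache file location, which is how get_bottleneck_path uses it. Note that the sketch raises ZeroDivisionError if the requested category list is empty. The other helper that main() calls but this post does not show, get_test_bottlenecks, would presumably loop over every class and every image in the 'testing' lists instead of sampling at random.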
[Figure: final run result]
The walkthrough above follows the code in 《TensorFlow 实战Google深度学习框架》. Since the functions are heavily nested and not very readable, below is my own reorganized version of the code, which makes the main processing steps easier to follow.

    import glob
    import os.path
    import random

    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x API
    from tensorflow.python.platform import gfile

    BOTTLENECK_TENSOR_SIZE = 2048
    BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
    JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'
    MODEL_DIR = 'D://python//Inception_dec_2015'
    MODEL_FILE = 'tensorflow_inception_graph.pb'
    INPUT_DATA = 'D://python//flower_photos//flower_photos'
    VALIDATION_PERCENTAGE = 10
    TEST_PERCENTAGE = 10
    LEARNING_RATE = 0.01
    STEPS = 4000
    BATCH = 100

    def create_image_list(test_percentage, validation_percentage):
        result = {}
        sub_dirs = [x[0] for x in os.walk(INPUT_DATA)]
        is_root_dir = True
        for sub_dir in sub_dirs:
            if is_root_dir:
                is_root_dir = False
                continue
            extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
            file_list = []
            dir_name = os.path.basename(sub_dir)
            for extension in extensions:
                file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)
                file_list.extend(glob.glob(file_glob))
            if not file_list:
                continue
            label_name = dir_name.lower()
            training_images = []
            testing_images = []
            validation_images = []
            for file_name in file_list:
                base_name = os.path.basename(file_name)
                chance = np.random.randint(100)
                if chance < validation_percentage:
                    validation_images.append(base_name)
                elif chance < (test_percentage + validation_percentage):
                    testing_images.append(base_name)
                else:
                    training_images.append(base_name)
            result[label_name] = {
                'dir': dir_name,
                'training': training_images,
                'testing': testing_images,
                'validation': validation_images
            }
        return result

    def get_necks(sess, n_classes, image_list, how_many, category, tensor_data, tensor_neck):
        bottlenecks = []
        ground_truths = []
        for _ in range(how_many):
            label_index = random.randrange(n_classes)
            label_name = list(image_list.keys())[label_index]
            image_index = random.randrange(65536)
            label_lists = image_list[label_name]
            category_list = label_lists[category]
            mod_index = image_index % len(category_list)
            # Build the image path (using the real directory name, not the lowercased label).
            image_path = os.path.join(INPUT_DATA, label_lists['dir'], category_list[mod_index])
            with gfile.FastGFile(image_path, 'rb') as image_file:
                image_data = image_file.read()
            # Extract the features.
            neck_values = sess.run(tensor_neck, {tensor_data: image_data})
            neck_values = np.squeeze(neck_values)
            ground_truth = np.zeros(n_classes, dtype=np.float32)
            ground_truth[label_index] = 1.0
            bottlenecks.append(neck_values)
            ground_truths.append(ground_truth)
        return bottlenecks, ground_truths

    def main():
        image_list = create_image_list(TEST_PERCENTAGE, VALIDATION_PERCENTAGE)
        n_classes = len(image_list.keys())
        with gfile.FastGFile(os.path.join(MODEL_DIR, MODEL_FILE), 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        tensor_neck, tensor_data = tf.import_graph_def(
            graph_def,
            return_elements=[BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME])
        input_neck = tf.placeholder(tf.float32, [None, BOTTLENECK_TENSOR_SIZE],
                                    name='BottleInput')
        input_truth = tf.placeholder(tf.float32, [None, n_classes], name='TruthInput')
        with tf.name_scope('final_training_ops'):
            weights = tf.Variable(tf.truncated_normal(
                [BOTTLENECK_TENSOR_SIZE, n_classes], stddev=0.001))
            biases = tf.Variable(tf.zeros([n_classes]))
            logits = tf.matmul(input_neck, weights) + biases
            final_tensor = tf.nn.softmax(logits)
        cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
            logits=logits, labels=input_truth)
        cross_entropy_mean = tf.reduce_mean(cross_entropy)
        train_step = tf.train.GradientDescentOptimizer(
            LEARNING_RATE).minimize(cross_entropy_mean)
        with tf.name_scope('evaluation'):
            correct_prediction = tf.equal(tf.argmax(final_tensor, 1),
                                          tf.argmax(input_truth, 1))
            correct_mean = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for i in range(STEPS):
                train_necks, train_truth = get_necks(
                    sess, n_classes, image_list, BATCH, 'training',
                    tensor_data, tensor_neck)
                sess.run(train_step, feed_dict={input_neck: train_necks,
                                                input_truth: train_truth})
                if i % 20 == 0 or i + 1 == STEPS:
                    valid_necks, valid_truth = get_necks(
                        sess, n_classes, image_list, BATCH, 'validation',
                        tensor_data, tensor_neck)
                    validation_accuracy = sess.run(correct_mean, feed_dict={
                        input_neck: valid_necks, input_truth: valid_truth})
                    print('Step %d: Validation accuracy on random sampled %d examples = %.1f%%' %
                          (i, BATCH, validation_accuracy * 100))
            test_necks, test_truth = get_necks(
                sess, n_classes, image_list, BATCH, 'testing',
                tensor_data, tensor_neck)
            test_accuracy = sess.run(correct_mean, feed_dict={input_neck: test_necks,
                                                              input_truth: test_truth})
            print('Final test accuracy = %.1f%%' % (test_accuracy * 100))

    if __name__ == '__main__':
        main()

Although the code logic is simpler, the running time increases considerably. This is mainly because the image files are read and their features computed over and over again.
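One possible middle ground, sketched here under assumptions, is to keep the flat structure but remember each image's features in memory after the first computation. In the sketch, `make_cached` and `fake_features` are hypothetical names; `fake_features` stands in for the expensive step of reading the file and running it through Inception-v3.

```python
def make_cached(compute_features):
    """Wrap a feature function so each image path is only computed once."""
    cache = {}

    def cached(image_path):
        if image_path not in cache:
            cache[image_path] = compute_features(image_path)
        return cache[image_path]

    return cached

# Stand-in for the expensive step; it records every call it receives.
calls = []

def fake_features(image_path):
    calls.append(image_path)
    return [float(len(image_path))]

get_features = make_cached(fake_features)
first = get_features('daisy/a.jpg')
second = get_features('daisy/a.jpg')  # served from the in-memory cache
print(first, second, len(calls))
```

Unlike the txt files used earlier in this post, an in-memory cache is lost when the program exits, and holding the features of every image in memory costs RAM, so which approach fits better depends on the dataset size.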