欢迎您访问程序员文章站!本站旨在为大家提供和分享程序员计算机编程知识。
您现在的位置是: 首页

YOLO代码解析(4)

程序员文章站 2022-06-21 16:26:24
...

下面介绍训练和测试代码。训练代码主要包括 graph 的构建、预训练模型的加载、训练过程中的数据读取,以及相关日志和模型文件的保存等内容;测试代码的主要部分是对模型预测结果的格式进行转换。

其他相关的部分请见:
YOLO代码解析(1) 代码总览与使用
YOLO代码解析(2) 数据处理
YOLO代码解析(3) 模型和损失函数
YOLO代码解析(4) 训练和测试代码

训练相关代码:yolo_solver.py

def _train(self):
    """Build the op that performs one optimisation step on the total loss.

    Takes no arguments besides ``self``. It reads ``self.learning_rate``,
    ``self.moment``, ``self.total_loss`` and ``self.global_step``, all of
    which must be set before this is called.

    Returns:
      The op produced by ``apply_gradients``; running it applies one
      momentum update.
    """
    # Momentum optimiser configured from the solver's hyper-parameters.
    optimizer = tf.train.MomentumOptimizer(self.learning_rate, self.moment)

    # Gradients are computed and applied in two explicit steps. The original
    # author notes this could also be written as one
    # ``minimize(self.total_loss, global_step=self.global_step)`` call.
    grads_and_vars = optimizer.compute_gradients(self.total_loss)
    return optimizer.apply_gradients(grads_and_vars,
                                     global_step=self.global_step)

  def construct_graph(self):
    """Build the training graph and store its pieces on ``self``.

    Creates, in order: the non-trainable ``global_step`` counter, the three
    input placeholders (``images``, ``labels``, ``objects_num``), the
    network output ``predicts``, the scalar ``total_loss`` (also registered
    as the 'loss' summary) and finally ``train_op``.

    Reads ``self.batch_size``, ``self.height``, ``self.width``,
    ``self.max_objects`` and ``self.net``.
    """
    # Step counter; excluded from the trainable variables.
    self.global_step = tf.Variable(0, trainable=False)

    # (1) Network inputs used during training.
    self.images = tf.placeholder(
        tf.float32, (self.batch_size, self.height, self.width, 3))
    self.labels = tf.placeholder(
        tf.float32, (self.batch_size, self.max_objects, 5))
    # A one-element shape needs the trailing comma to be a tuple:
    # "(self.batch_size)" is just the int in parentheses.
    self.objects_num = tf.placeholder(tf.int32, (self.batch_size,))

    # (2) Inference. Per the original author's note the output is a tensor
    # of shape (N, cell_size, cell_size, class_num + box_num * 5).
    self.predicts = self.net.inference(self.images)

    # (3) Loss.
    self.total_loss = self.net.loss(self.predicts, self.labels, self.objects_num)

    tf.summary.scalar('loss', self.total_loss)
    self.train_op = self._train()

  def solve(self):
    """Run the training loop.

    Initialises all variables, restores the pretrained weights from
    ``self.pretrain_path``, then runs ``self.max_iterators`` training steps
    on batches from ``self.dataset.batch()``. Every 10 steps a progress line
    is printed, every 100 steps a summary is written to ``self.train_dir``,
    and every 5000 steps a checkpoint is saved there.

    The session and the summary writer are always released, including when
    a step raises.

    Raises:
      AssertionError: if the loss becomes NaN.
    """
    # saver1 restores the pretrained variables; saver2 writes checkpoints.
    saver1 = tf.train.Saver(self.net.pretrained_collection, write_version=1)
    saver2 = tf.train.Saver(self.net.trainable_collection, write_version=1)

    # Variable initialisation and merged summaries.
    init = tf.global_variables_initializer()
    summary_op = tf.summary.merge_all()

    sess = tf.Session()
    summary_writer = None
    try:
      sess.run(init)

      # Load the pretrained model (after init, so restored values win).
      saver1.restore(sess, self.pretrain_path)

      # Create the event file writer.
      summary_writer = tf.summary.FileWriter(self.train_dir, sess.graph)

      for step in range(self.max_iterators):
        start_time = time.time()
        # Fetch one batch of training data.
        np_images, np_labels, np_objects_num = self.dataset.batch()
        feed = {self.images: np_images,
                self.labels: np_labels,
                self.objects_num: np_objects_num}

        _, loss_value = sess.run([self.train_op, self.total_loss], feed_dict=feed)

        duration = time.time() - start_time

        # Explicit raise instead of ``assert`` so the check survives ``-O``;
        # the exception type and message are unchanged.
        if np.isnan(loss_value):
          raise AssertionError('Model diverged with loss = NaN')

        if step % 10 == 0:
          num_examples_per_step = self.dataset.batch_size
          examples_per_sec = num_examples_per_step / duration
          sec_per_batch = float(duration)

          format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f sec/batch)')
          print(format_str % (datetime.now(), step, loss_value,
                              examples_per_sec, sec_per_batch))

          sys.stdout.flush()
        if step % 100 == 0:  # write the event file
          summary_str = sess.run(summary_op, feed_dict=feed)
          summary_writer.add_summary(summary_str, step)
        if step % 5000 == 0:  # save a checkpoint
          saver2.save(sess, self.train_dir + '/model.ckpt', global_step=step)
    finally:
      # Release resources on failure too, so buffered summaries are written
      # and the session is not leaked.
      if summary_writer is not None:
        summary_writer.close()
      sess.close()

测试相关代码:demo.py

# 对网络给出的预测结果做处理
def process_predicts(predicts, image_size=448, num_classes=20, boxes_per_cell=2):
    """Turn the raw network output into the single best box, in pixels.

    Only the first item of the batch is used. Its last axis is read as
    ``[class probabilities | one confidence per box | 4 coordinates per box]``,
    i.e. ``num_classes + boxes_per_cell * 5`` values per grid cell. The grid
    size is taken from ``predicts.shape[1:3]``.

    Args:
      predicts: array of shape (N, rows, cols, num_classes + boxes_per_cell*5).
      image_size: side length in pixels of the square network input; used to
        scale the box back to pixel units. Defaults to 448.
      num_classes: number of class-probability channels. Defaults to 20.
      boxes_per_cell: number of boxes predicted per cell. Defaults to 2.

    Returns:
      ``(xmin, ymin, xmax, ymax, class_num)`` for the (cell, box, class)
      combination with the highest ``confidence * class probability``.

    Raises:
      ValueError: if the last axis of ``predicts`` does not have
        ``num_classes + boxes_per_cell * 5`` entries.
    """
    rows, cols = predicts.shape[1], predicts.shape[2]
    conf_end = num_classes + boxes_per_cell
    expected_depth = conf_end + boxes_per_cell * 4
    if predicts.shape[3] != expected_depth:
        raise ValueError(
            'expected last dimension %d (num_classes + boxes_per_cell * 5), '
            'got %d' % (expected_depth, predicts.shape[3]))

    p_classes = predicts[0, :, :, :num_classes]   # class probabilities
    C = predicts[0, :, :, num_classes:conf_end]   # per-box object confidence
    coordinate = predicts[0, :, :, conf_end:]     # predicted box coordinates
    print(predicts.shape)  # diagnostic output kept from the original

    # Add singleton axes so confidence (.., boxes, 1) broadcasts against
    # class probability (.., 1, classes).
    p_classes = np.reshape(p_classes, (rows, cols, 1, num_classes))
    C = np.reshape(C, (rows, cols, boxes_per_cell, 1))

    # P = object confidence * class probability.
    P = C * p_classes
    print(P.shape)  # diagnostic output kept from the original

    # Locate the (row, col, box, class) entry with the highest score.
    row, col, box, class_num = np.unravel_index(np.argmax(P), P.shape)

    coordinate = np.reshape(coordinate, (rows, cols, boxes_per_cell, 4))
    xcenter, ycenter, w, h = coordinate[row, col, box, :]

    # The predicted centre is an offset from the cell's top-left corner,
    # normalised by the cell size: undo the offset, then the normalisation.
    xcenter = (col + xcenter) * (image_size / float(cols))
    ycenter = (row + ycenter) * (image_size / float(rows))
    # Width and height are normalised by the image size.
    w = w * image_size
    h = h * image_size

    xmin = xcenter - w / 2.0
    ymin = ycenter - h / 2.0
    xmax = xmin + w
    ymax = ymin + h

    # NOTE: this detection is deliberately simple: it returns only the
    # single highest-scoring box and class. A full detector would score
    # every class separately and prune overlapping candidates with NMS.
    return xmin, ymin, xmax, ymax, class_num


common_params = {'image_size': 448, 'num_classes': 20, 'batch_size': 1}
net_params = {'cell_size': 7, 'boxes_per_cell': 2, 'weight_decay': 0.0005}

# Network, input placeholder and output tensor.
net = YoloTinyNet(common_params, net_params, test=True)
image = tf.placeholder(tf.float32, (1, 448, 448, 3))
predicts = net.inference(image)

sess = tf.Session()
try:
    # Read the image. cv2.imread returns None instead of raising when the
    # file is missing or unreadable, so fail here with a clear message.
    np_img = cv2.imread('cat.jpg')
    if np_img is None:
        raise IOError('Could not read image: cat.jpg')
    height, width, channels = np_img.shape
    print(height, width, channels)

    # Preprocess: resize to the network input size, convert BGR -> RGB and
    # map pixel values to [-1, 1].
    resized_img = cv2.resize(np_img, (448, 448))
    np_img = cv2.cvtColor(resized_img, cv2.COLOR_BGR2RGB)
    np_img = np_img.astype(np.float32)
    np_img = np_img / 255.0 * 2 - 1
    np_img = np.reshape(np_img, (1, 448, 448, 3))

    # Load the model and run a forward pass to get the detection result.
    saver = tf.train.Saver()
    saver.restore(sess, 'models/pretrain/yolo_tiny.ckpt')
    np_predict = sess.run(predicts, feed_dict={image: np_img})

    xmin, ymin, xmax, ymax, class_num = process_predicts(np_predict)
    class_name = classes_name[class_num]

    # The box is in 448x448 coordinates, so it is drawn on the resized
    # (still BGR) image rather than on the original-size one.
    cv2.rectangle(resized_img, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (0, 0, 255))
    cv2.putText(resized_img, class_name, (int(xmin), int(ymin)),
                cv2.FONT_HERSHEY_DUPLEX, 1.5, (0, 0, 255))
    cv2.imwrite('cat_out.jpg', resized_img)
finally:
    # Close the session even if reading, restoring or inference fails.
    sess.close()

其他没有提到的部分代码请见完整代码
另外对代码中涉及到的一些TensorFlow的函数的使用做了一个简单的整理,详见tensorflow函数