
PointNet++ Classification: train.py

程序员文章站 2022-04-19 12:33:33
```python
import os
import sys
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.Session, tf.train.*)

# MODEL, BATCH_SIZE, NUM_POINT, GPU_INDEX, OPTIMIZER, MOMENTUM, LOG_DIR, MAX_EPOCH,
# get_bn_decay, get_learning_rate, log_string, train_one_epoch and eval_one_epoch
# are defined elsewhere in train.py.

def train():
    with tf.Graph().as_default():
        with tf.device('/gpu:' + str(GPU_INDEX)):
            pointclouds_pl, labels_pl = MODEL.placeholder_inputs(BATCH_SIZE, NUM_POINT)
            # pointclouds_pl: (16, 1024, 3)   labels_pl: (16,)
            is_training_pl = tf.placeholder(tf.bool, shape=())  # scalar flag: training vs. evaluation

            # Note the global_step=batch argument passed to minimize() below: it tells the
            # optimizer to increment the 'batch' variable every time it runs a training step.
            batch = tf.get_variable('batch', [],
                                    initializer=tf.constant_initializer(0), trainable=False)
            bn_decay = get_bn_decay(batch)  # decay for the batch-norm moving averages
            tf.summary.scalar('bn_decay', bn_decay)

            # Get model and loss
            pred, end_points = MODEL.get_model(pointclouds_pl, is_training_pl, bn_decay=bn_decay)
            # pred: (16, 40), one score per class for each of the 40 classes.
            # end_points: dict of extra tensors from the model; it carries the original
            # (16, 1024, 3) input point cloud.
            MODEL.get_loss(pred, labels_pl, end_points)       # computes the loss and registers it in 'losses'
            losses = tf.get_collection('losses')               # list of all loss tensors
            total_loss = tf.add_n(losses, name='total_loss')   # sum of every loss in that list
            tf.summary.scalar('total_loss', total_loss)
            for l in losses + [total_loss]:
                tf.summary.scalar(l.op.name, l)

            # Per-sample correctness for this batch (bool vector of shape (16,))
            correct = tf.equal(tf.argmax(pred, 1), tf.to_int64(labels_pl))
            # Accuracy of a single batch: number of correct predictions / batch size
            accuracy = tf.reduce_sum(tf.cast(correct, tf.float32)) / float(BATCH_SIZE)
            tf.summary.scalar('accuracy', accuracy)

            print("--- Get training operator")
            # Build the training op
            learning_rate = get_learning_rate(batch)
            tf.summary.scalar('learning_rate', learning_rate)
            if OPTIMIZER == 'momentum':
                optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=MOMENTUM)
            elif OPTIMIZER == 'adam':
                optimizer = tf.train.AdamOptimizer(learning_rate)
            # Minimize the total loss; global_step=batch increments 'batch' after each step
            train_op = optimizer.minimize(total_loss, global_step=batch)

            # Add ops to save and restore all the variables.
            saver = tf.train.Saver()
```

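This is the PointNet++ training function. It builds the graph for the training network: it creates placeholders for the points and the labels, then wires up the model → loss → optimizer → model saving (a tf.train.Saver). For every batch it also works out which instances were classified correctly and computes the batch accuracy.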
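To make the two metric ops concrete, here is a NumPy sketch of the same computation. It is an illustration rather than code from the repository, and batch_accuracy is a hypothetical helper. It takes the argmax over the class scores, compares the result with the integer labels, and divides the number of matches by the batch size, as the accuracy op above does.

```python
import numpy as np

def batch_accuracy(pred, labels, batch_size):
    """NumPy mirror of the TF metric ops: argmax over class scores,
    compare with integer labels, count matches, divide by batch size."""
    correct = np.argmax(pred, axis=1) == labels.astype(np.int64)  # bool, shape (B,)
    num_correct = int(np.sum(correct))
    return num_correct, num_correct / float(batch_size)

# Toy batch: 4 instances, 3 classes (the real model uses 16 x 40).
pred = np.array([[0.1, 2.0, 0.3],
                 [1.5, 0.2, 0.1],
                 [0.0, 0.1, 0.9],
                 [0.7, 0.8, 0.6]])
labels = np.array([1, 0, 0, 1])
print(batch_accuracy(pred, labels, batch_size=4))  # (3, 0.75)
```

The TF version divides by the constant BATCH_SIZE. If a batch were ever padded with dummy rows, those rows would still count in the denominator.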

```python
        # (continued: still inside `with tf.Graph().as_default():` in train())
        # Create a session
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True   # allocate GPU memory on demand
        config.allow_soft_placement = True       # fall back to CPU when an op cannot run on the GPU
        config.log_device_placement = False
        sess = tf.Session(config=config)

        # Add summary writers
        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'train'), sess.graph)
        test_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'test'), sess.graph)

        # Init variables
        init = tf.global_variables_initializer()
        sess.run(init)

        ops = {'pointclouds_pl': pointclouds_pl,
               'labels_pl': labels_pl,
               'is_training_pl': is_training_pl,
               'pred': pred,
               'loss': total_loss,
               'train_op': train_op,
               'merged': merged,
               'step': batch,
               'end_points': end_points}

        best_acc = -1
        for epoch in range(MAX_EPOCH):
            log_string('**** EPOCH %03d ****' % (epoch))
            sys.stdout.flush()

            train_one_epoch(sess, ops, train_writer)
            eval_one_epoch(sess, ops, test_writer)

            # Save the variables to disk.
            if epoch % 10 == 0:
                save_path = saver.save(sess, os.path.join(LOG_DIR, "model.ckpt"))
                log_string("Model saved in file: %s" % save_path)
```

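The placeholders and ops are collected into the ops dict, and then the epoch loop starts. Each epoch runs train_one_epoch followed by eval_one_epoch. The model is saved every 10 epochs, that is, whenever epoch % 10 == 0 (epochs 0, 10, 20, and so on).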
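You can check the control flow of that loop without TensorFlow. In the sketch below, run_epochs and the no-op callbacks are hypothetical stand-ins for train_one_epoch, eval_one_epoch and saver.save. The function records which epochs trigger a save.

```python
def run_epochs(max_epoch, train_one_epoch, eval_one_epoch, save):
    """Same control flow as the loop above: train, evaluate, then save
    when epoch % 10 == 0. Returns the epochs at which save() was called."""
    saved = []
    for epoch in range(max_epoch):
        train_one_epoch(epoch)
        eval_one_epoch(epoch)
        if epoch % 10 == 0:
            save(epoch)
            saved.append(epoch)
    return saved

def noop(epoch):
    """Placeholder for the real training / evaluation / saving calls."""
    return None

print(run_epochs(25, noop, noop, noop))  # [0, 10, 20]
```

With 25 epochs the last save happens at epoch 20, so epochs 21 to 24 are never checkpointed. Every call also writes to the same model.ckpt path, so each save replaces the previous checkpoint instead of keeping a history.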

Differences from PointNet

```python
def train_one_epoch(sess, ops, train_writer):
    # ... parameter/variable initialization (omitted) ...
    while TRAIN_DATASET.has_next_batch():
        batch_data, batch_label = TRAIN_DATASET.next_batch(augment=True)
        # The step above fetches a preprocessed (and augmented) batch; training the
        # network on it and the related ops follow (omitted). All of the data
        # handling lives in modelnet_h5_dataset.py.
```

train_one_epoch and eval_one_epoch are largely the same as in PointNet. The main change is that the authors moved the data handling into modelnet_h5_dataset.py, where they created a ModelNetH5Dataset class.
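The placeholders are built with a fixed batch size (16 in the comments above). However, next_batch may return fewer samples for the last batch of a file, so a training loop has to reconcile the two before feeding the session. One common approach is to copy the short batch into zero-filled arrays of the full size and remember how many rows are real. The sketch below shows this with a hypothetical NumPy helper, pad_batch; it is not the repository's code.

```python
import numpy as np

BATCH_SIZE = 16   # placeholder batch size, as in the comments above
NUM_POINT = 1024  # points per object, as in the comments above

def pad_batch(batch_data, batch_label, batch_size=BATCH_SIZE):
    """Copy a possibly short batch into fixed-size zero arrays.

    Returns the padded arrays plus bsize, the number of real samples,
    so callers can ignore the padded rows when computing metrics."""
    bsize = batch_data.shape[0]
    cur_data = np.zeros((batch_size,) + batch_data.shape[1:], dtype=batch_data.dtype)
    cur_label = np.zeros((batch_size,), dtype=np.int32)
    cur_data[:bsize] = batch_data
    cur_label[:bsize] = batch_label
    return cur_data, cur_label, bsize

# A short last batch holding only 6 objects.
short_data = np.random.rand(6, NUM_POINT, 3).astype(np.float32)
short_label = np.arange(6)
data, label, bsize = pad_batch(short_data, short_label)
print(data.shape, label.shape, bsize)  # (16, 1024, 3) (16,) 6
```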

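ModelNetH5Dataset
Initialization of the class:
1. h5_files = getDataFiles(self.list_filename) reads the list file (train_file) line by line. Each line holds the name of one .h5 file.
2. reset() is defined to shuffle the order of the h5 files: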
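Step 1 amounts to reading a text file that lists one path per line. Below is a minimal sketch. get_data_files is a hypothetical name standing in for getDataFiles, and the .h5 file names are made up. Unlike a plain line-by-line read, this version also skips blank lines.

```python
import os
import tempfile

def get_data_files(list_filename):
    """Return one entry per non-blank line of the list file, newline stripped."""
    with open(list_filename) as f:
        return [line.rstrip() for line in f if line.strip()]

# Write a throwaway list file with two (made-up) h5 names and a trailing blank line.
with tempfile.TemporaryDirectory() as tmp:
    list_path = os.path.join(tmp, "train_files.txt")
    with open(list_path, "w") as f:
        f.write("ply_data_train0.h5\nply_data_train1.h5\n\n")
    files = get_data_files(list_path)

print(files)  # ['ply_data_train0.h5', 'ply_data_train1.h5']
```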

```python
    def reset(self):
        ''' reset order of h5 files '''
        # Build an index over the h5 files and shuffle it
        self.file_idxs = np.arange(0, len(self.h5_files))
        if self.shuffle:
            np.random.shuffle(self.file_idxs)
        self.current_data = None
        self.current_label = None
        self.current_file_idx = 0
        self.batch_idx = 0
```

has_next_batch()

```python
    def has_next_batch(self):
        # TODO: add backend thread to load data
        # Nothing loaded yet, or the current file is exhausted: move on to the next file
        if (self.current_data is None) or (not self._has_next_batch_in_file()):
            if self.current_file_idx >= len(self.h5_files):
                return False  # every h5 file has been consumed
            self._load_data_file(self._get_data_filename())
            self.batch_idx = 0
            self.current_file_idx += 1
        return self._has_next_batch_in_file()
```

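has_next_batch() first checks two conditions: whether no file has been loaded yet (self.current_data is None), or whether the current file has run out of batches. If either holds and every h5 file has already been read, it returns False. Otherwise it loads the next h5 file. It reads one whole file at a time:
self._load_data_file(self._get_data_filename())

self._get_data_filename() picks the next file name through the shuffled file index built in reset(). This shuffles at the file level.
self._load_data_file() reads data and label with h5py.File(h5_filename) and shuffles data and labels together, so each sample keeps its label. This shuffles at the sample level. The result is stored in self.current_data and self.current_label.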
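You can reproduce how reset(), has_next_batch(), _load_data_file() and next_batch() work together without h5py. In the sketch below, in-memory (data, label) pairs stand in for the h5 files, an assumption made so the example runs anywhere. InMemoryChunkedDataset and make_chunk are hypothetical names. Each sample's coordinates are filled with its own id, and its label is that same id. This lets the loop verify that data and labels stay aligned through both levels of shuffling.

```python
import numpy as np

class InMemoryChunkedDataset:
    """Sketch of the ModelNetH5Dataset iteration logic. In-memory
    (data, label) chunks stand in for h5 files (an assumption)."""

    def __init__(self, chunks, batch_size=4, shuffle=True, seed=0):
        self.chunks = chunks            # list of (data, label) pairs, one per "file"
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = np.random.default_rng(seed)
        self.reset()

    def reset(self):
        self.file_idxs = np.arange(len(self.chunks))
        if self.shuffle:
            self.rng.shuffle(self.file_idxs)  # file-level shuffle
        self.current_data = None
        self.current_label = None
        self.current_file_idx = 0
        self.batch_idx = 0

    def _load_chunk(self, idx):
        data, label = self.chunks[idx]
        if self.shuffle:
            perm = self.rng.permutation(len(label))  # one permutation for both arrays
            data, label = data[perm], label[perm]    # sample-level shuffle
        self.current_data, self.current_label = data, label
        self.batch_idx = 0

    def _has_next_batch_in_file(self):
        return self.batch_idx * self.batch_size < self.current_data.shape[0]

    def has_next_batch(self):
        if self.current_data is None or not self._has_next_batch_in_file():
            if self.current_file_idx >= len(self.chunks):
                return False
            self._load_chunk(self.file_idxs[self.current_file_idx])
            self.current_file_idx += 1
        return self._has_next_batch_in_file()

    def next_batch(self):
        start = self.batch_idx * self.batch_size
        end = min(start + self.batch_size, self.current_data.shape[0])
        self.batch_idx += 1
        return self.current_data[start:end].copy(), self.current_label[start:end].copy()

def make_chunk(start, n, num_points=8):
    """Fake 'file' of n objects; every coordinate of object i equals its id."""
    ids = np.arange(start, start + n)
    data = np.broadcast_to(ids[:, None, None], (n, num_points, 3)).astype(np.float32)
    return data, ids

ds = InMemoryChunkedDataset([make_chunk(0, 10), make_chunk(10, 6)], batch_size=4)
sizes, seen = [], []
while ds.has_next_batch():
    data, label = ds.next_batch()
    assert np.all(data[:, 0, 0] == label)  # data and labels stayed aligned
    sizes.append(len(label))
    seen.extend(label.tolist())
print(sorted(sizes), sorted(seen) == list(range(16)))  # [2, 2, 4, 4, 4] True
```

With batch_size=4, a file of 10 samples yields batches of 4, 4 and 2, and a file of 6 yields 4 and 2. The shuffle decides which file is visited first.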

next_batch()

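next_batch() slices the point clouds of the next batch_size objects out of the currently loaded h5 file for training. As its docstring notes, the last batch of a file can be smaller than batch_size.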


```python
    def next_batch(self, augment=False):
        ''' returned dimension may be smaller than self.batch_size '''
        # Start and end indices of this batch within the current h5 file
        start_idx = self.batch_idx * self.batch_size
        end_idx = min((self.batch_idx+1) * self.batch_size, self.current_data.shape[0])
        bsize = end_idx - start_idx
        batch_label = np.zeros((bsize), dtype=np.int32)  # unused: label_batch is what gets returned
        # Copy this batch's points (the first npoints of each object) out of the full data
        data_batch = self.current_data[start_idx:end_idx, 0:self.npoints, :].copy()
        # Copy the matching labels
        label_batch = self.current_label[start_idx:end_idx].copy()
        self.batch_idx += 1
        # Data augmentation: apply random rotation and perturbation to the point clouds
        if augment:
            data_batch = self._augment_batch_data(data_batch)
        return data_batch, label_batch
```
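This post does not show the body of _augment_batch_data. The helpers below are a rough sketch of the two augmentations the comment mentions, and their names are hypothetical. Each cloud is rotated by a random angle about the y axis, which is assumed here to be the up axis. Gaussian jitter is then added. The values sigma=0.01 and clip=0.05 are assumptions, not values taken from modelnet_h5_dataset.py.

```python
import numpy as np

def rotate_about_y(batch, rng):
    """Rotate each cloud in a (B, N, 3) batch by its own random angle about y (assumed up axis)."""
    out = np.empty_like(batch)
    for k, cloud in enumerate(batch):
        theta = rng.uniform(0.0, 2.0 * np.pi)
        c, s = np.cos(theta), np.sin(theta)
        rot = np.array([[c, 0.0, s],
                        [0.0, 1.0, 0.0],
                        [-s, 0.0, c]])
        out[k] = cloud @ rot  # row-vector points times rotation matrix
    return out

def jitter(batch, rng, sigma=0.01, clip=0.05):
    """Add Gaussian noise to every coordinate, clipped to [-clip, clip] (assumed values)."""
    noise = np.clip(sigma * rng.standard_normal(batch.shape), -clip, clip)
    return (batch + noise).astype(batch.dtype)

def augment_batch(batch, seed=0):
    """Rotation followed by jitter, driven by one seeded generator."""
    rng = np.random.default_rng(seed)
    return jitter(rotate_about_y(batch, rng), rng)

batch = np.random.default_rng(1).uniform(-1.0, 1.0, size=(2, 1024, 3)).astype(np.float32)
rotated = rotate_about_y(batch, np.random.default_rng(2))
augmented = augment_batch(batch, seed=3)
# A rotation about y leaves the y coordinate and each point's distance to the origin unchanged.
print(np.allclose(rotated[..., 1], batch[..., 1]),
      np.allclose(np.linalg.norm(rotated, axis=-1), np.linalg.norm(batch, axis=-1), atol=1e-5))
```

A rotation about the up axis keeps objects upright, which suits ModelNet shapes that are stored in a consistent vertical orientation. The small clipped jitter perturbs points without changing the overall shape.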

Original article: https://blog.csdn.net/CSDNcylinux/article/details/107149540