PointNet++分类train.py
def train():
    """Build the PointNet++ classification training graph (graph-construction part).

    Creates the input placeholders, the model and its loss, batch accuracy,
    the learning-rate / batch-norm-decay schedules driven by the global step
    variable ``batch``, the optimizer's train op and a Saver. Session creation
    and the epoch loop follow in the rest of this function.

    Raises:
        ValueError: if OPTIMIZER is neither 'momentum' nor 'adam'.
    """
    with tf.Graph().as_default():
        with tf.device('/gpu:'+str(GPU_INDEX)):
            # Presumably (BATCH_SIZE, NUM_POINT, 3) point clouds and (BATCH_SIZE,)
            # labels, e.g. (16, 1024, 3) and (16,) -- confirm in MODEL.placeholder_inputs.
            pointclouds_pl, labels_pl = MODEL.placeholder_inputs(BATCH_SIZE, NUM_POINT)
            # Training-mode flag must be a scalar bool. The former shape=(16) is the
            # plain int 16, i.e. a length-16 bool vector, which a single True/False
            # feed cannot satisfy.
            is_training_pl = tf.placeholder(tf.bool, shape=())

            # Note the global_step=batch parameter to minimize below: it tells the
            # optimizer to increment `batch` every time it runs a training step.
            batch = tf.get_variable('batch', [],
                initializer=tf.constant_initializer(0), trainable=False)
            bn_decay = get_bn_decay(batch)  # batch-norm decay, scheduled on the step counter
            tf.summary.scalar('bn_decay', bn_decay)

            # Get model and loss.
            # pred presumably holds per-class scores, e.g. (16, 40) for 40 classes;
            # end_points is whatever MODEL.get_model exposes to get_loss -- verify
            # against the model file.
            pred, end_points = MODEL.get_model(pointclouds_pl, is_training_pl, bn_decay=bn_decay)
            # get_loss's return value is ignored: its terms are read back from the
            # 'losses' collection and summed into the total loss.
            MODEL.get_loss(pred, labels_pl, end_points)
            losses = tf.get_collection('losses')
            total_loss = tf.add_n(losses, name='total_loss')
            tf.summary.scalar('total_loss', total_loss)
            for loss_tensor in losses + [total_loss]:
                tf.summary.scalar(loss_tensor.op.name, loss_tensor)

            # Per-sample "prediction == label" mask, then the fraction correct in one batch.
            correct = tf.equal(tf.argmax(pred, 1), tf.to_int64(labels_pl))
            accuracy = tf.reduce_sum(tf.cast(correct, tf.float32)) / float(BATCH_SIZE)
            tf.summary.scalar('accuracy', accuracy)

            print("--- Get training operator")
            learning_rate = get_learning_rate(batch)
            tf.summary.scalar('learning_rate', learning_rate)
            if OPTIMIZER == 'momentum':
                optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=MOMENTUM)
            elif OPTIMIZER == 'adam':
                optimizer = tf.train.AdamOptimizer(learning_rate)
            else:
                # Previously fell through and failed later with NameError on `optimizer`.
                raise ValueError(
                    "Unsupported OPTIMIZER %r (expected 'momentum' or 'adam')" % (OPTIMIZER,))
            # Minimizing with global_step=batch also increments `batch` each step.
            train_op = optimizer.minimize(total_loss, global_step=batch)

            # Add ops to save and restore all the variables.
            saver = tf.train.Saver()
PointNet++ 的训练函数：构建训练网络的计算图，为点云（point）和标签（label）创建占位符，并依次搭建"模型前向计算 → 计算损失 → 优化器优化 → 模型保存（Saver）"的流程。同时，每个批次训练完成后，计算该批次中分类正确的实例个数和分类正确率。
# Create a session (continuation of train(): session, summaries and the epoch loop).
config = tf.ConfigProto()
config.gpu_options.allow_growth = True  # allocate GPU memory on demand
config.allow_soft_placement = True  # let ops without a GPU kernel fall back to another device
config.log_device_placement = False
sess = tf.Session(config=config)

# Add summary writers (separate event directories for train and test curves).
merged = tf.summary.merge_all()
train_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'train'), sess.graph)
test_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'test'), sess.graph)

# Init variables
init = tf.global_variables_initializer()
sess.run(init)

# Graph handles passed to the per-epoch train/eval routines.
ops = {'pointclouds_pl': pointclouds_pl,
       'labels_pl': labels_pl,
       'is_training_pl': is_training_pl,
       'pred': pred,
       'loss': total_loss,
       'train_op': train_op,
       'merged': merged,
       'step': batch,
       'end_points': end_points}

for epoch in range(MAX_EPOCH):
    log_string('**** EPOCH %03d ****' % (epoch))
    sys.stdout.flush()
    train_one_epoch(sess, ops, train_writer)
    eval_one_epoch(sess, ops, test_writer)
    # Save the variables to disk every 10 epochs (0, 10, 20, ...) and always after
    # the last epoch, so the final weights are kept even when MAX_EPOCH - 1 is
    # not a multiple of 10.
    if epoch % 10 == 0 or epoch == MAX_EPOCH - 1:
        save_path = saver.save(sess, os.path.join(LOG_DIR, "model.ckpt"))
        log_string("Model saved in file: %s" % save_path)

# Flush any pending summary events to disk.
train_writer.close()
test_writer.close()
随后为 ops 字典赋值，进入训练循环：每个 epoch 依次执行 train_one_epoch 和 eval_one_epoch，并且每 10 个 epoch（即 epoch % 10 == 0，第 0、10、20…… 个 epoch）保存一次模型。
与 PointNet 代码的不同之处
def train_one_epoch(sess, ops, train_writer):
# Train for one epoch over TRAIN_DATASET.
# NOTE(review): blog excerpt -- the variable initialisation (elided on the next
# line) and the per-batch training step after next_batch are not shown here.
'''''''''''''''''''''''''''''''''' 参数变量初始化'''''''''''''''''''''''''''''''''''''''''''''
# The line above is the post's elision marker ("parameter/variable initialisation").
while TRAIN_DATASET.has_next_batch():
batch_data, batch_label = TRAIN_DATASET.next_batch(augment=True)
# Fetch the next batch; augment=True makes next_batch pass the data through
# _augment_batch_data. All data preprocessing lives in modelnet_h5_dataset.py;
# the network training steps that follow are omitted from this excerpt.
"""上面步骤进行数据预处理,之后进行训练网络等一系列操作,数据处理都写进modelnet_h5_dataset.py"""
train_one_epoch和eval_one_epoch与PointNet中的大同小异,唯一变化的是作者将数据处理部分写进了modelnet_h5_dataset.py中,创建了一个ModelNetH5Dataset()类。
ModelNetH5Dataset():
该类初始化
1.首先h5_files = getDataFiles(self.list_filename),将train_file中的h5文件名每行读出。
2. 定义 reset()：重新生成 h5 文件索引（shuffle 为真时将其打乱），并清空当前已加载的数据，同时重置文件索引和批次索引。
def reset(self):
    """Rewind the dataset to the start of an epoch.

    Rebuilds the order in which h5 files are visited (shuffled in place when
    ``self.shuffle`` is truthy), drops the cached file data and labels, and
    restarts both the file cursor and the in-file batch cursor.
    """
    order = np.arange(0, len(self.h5_files))  # one index per h5 file
    self.file_idxs = order
    if self.shuffle:
        # Shuffles in place, so self.file_idxs sees the new order.
        np.random.shuffle(order)
    self.current_data = self.current_label = None
    self.current_file_idx = self.batch_idx = 0
has_next_batch()
def has_next_batch(self):
    """Return True if another batch can be read, loading the next h5 file if needed.

    When no file is loaded yet, or the loaded file has no batches left, the
    file chosen by ``self._get_data_filename()`` is loaded and the in-file
    batch cursor is reset. Files that yield no batch (e.g. an empty h5 file)
    are skipped instead of ending the epoch early.

    Returns:
        bool: False once every h5 file has been consumed, True otherwise.
    """
    # TODO: add backend thread to load data
    # Each iteration either returns or advances current_file_idx, so this terminates.
    while (self.current_data is None) or (not self._has_next_batch_in_file()):
        if self.current_file_idx >= len(self.h5_files):
            return False
        self._load_data_file(self._get_data_filename())
        self.batch_idx = 0
        self.current_file_idx += 1
    return True
首先判断是否尚未加载任何文件（current_data 为 None），或者当前文件中的批次是否已经取完；若是，再判断所有 h5 文件是否都已遍历完，已遍历完则返回 False，否则读取下一个文件。
读取文件中的数据,每次从中读出一个h5文件的数据
self._load_data_file(self._get_data_filename())
self._get_data_filename(),根据之前创建的打乱的文件索引读出文件名 (读文件打乱)
self._load_data_file() 利用 h5py.File(h5_filename) 读出 data 和 label，将 data 和 label 按相同顺序打乱（数据打乱），并保存为 self.current_data 和 self.current_label。
.next_batch()
从单个h5文件中提取,按照批次尺寸提取出相应尺寸的物体点云数据用于训练
def next_batch(self, augment=False):
    """Return the next (data, label) batch from the currently loaded h5 file.

    The returned batch may be smaller than ``self.batch_size`` when the file's
    remaining samples do not fill a whole batch. Only the first
    ``self.npoints`` points of each sample are kept. Advances ``self.batch_idx``.

    Args:
        augment: if True, the data batch is passed through
            ``self._augment_batch_data`` before being returned.

    Returns:
        tuple: (data_batch, label_batch), copies of the corresponding slices of
        ``self.current_data`` and ``self.current_label``.
    """
    # Sample range [start_idx, end_idx) of this batch, clipped to the file size.
    start_idx = self.batch_idx * self.batch_size
    end_idx = min((self.batch_idx + 1) * self.batch_size, self.current_data.shape[0])
    # Copies, so in-place changes to the returned arrays do not touch the cached file data.
    data_batch = self.current_data[start_idx:end_idx, 0:self.npoints, :].copy()
    label_batch = self.current_label[start_idx:end_idx].copy()
    self.batch_idx += 1
    if augment:
        # Per the original notes, augmentation adds rotated/jittered point clouds.
        data_batch = self._augment_batch_data(data_batch)
    return data_batch, label_batch
本文地址:https://blog.csdn.net/CSDNcylinux/article/details/107149540