欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Unet项目解析(1): run_training.py

程序员文章站 2022-03-17 14:53:57
...

项目GitHub主页:https://github.com/orobix/retina-unet

参考论文:Retina blood vessel segmentation with a convolution neural network (U-net) Retina blood vessel segmentation with a convolution neural network (U-net)

Unet项目解析(1): run_training.py

1. run_training.py解析

import os, sys # os模块中主要用于处理文件和目录 
import configparser # Python 3.6中 configparser全使用小写

#config file to read from
config = configparser.RawConfigParser() 
config.readfp(open(r'./configuration.txt')) # 建议使用 config.read('configuration.txt') #'configuration.txt'的内容见下面
#===========================================
#name of the experiment
name_experiment = config.get('experiment name', 'name')
nohup = config.getboolean('training settings', 'nohup')   #std output on log file?

run_GPU = '' if sys.platform == 'win32' else ' THEANO_FLAGS=device=gpu,floatX=float32 ' #是否用GPU进行训练

#create a folder for the results 创建文件夹用于保存结果
result_dir = name_experiment
print ("\n 1. Create directory for the results (if not already existing)")
if os.path.exists(result_dir):
    print ("Dir already existing")  # 用于保存结果的test文件夹如果存在就没有必要创建
elif sys.platform=='win32':
    os.system('mkdir ' + result_dir)
else:
    os.system('mkdir -p ' +result_dir) # 需要时创建上层目录,如目录早已存在则不当作错误

print ("copy the configuration file in the results folder")
if sys.platform=='win32':
    os.system('copy configuration.txt .\\' +name_experiment+'\\'+name_experiment+'_configuration.txt')
else:
    os.system('cp configuration.txt ./' +name_experiment+'/'+name_experiment+'_configuration.txt')

# run the experiment
if nohup: #作者采用不挂断的方式运行命令
    print ("\n2. Run the training with nohup, no GPU ")
    os.system(' nohup python -u ./src/retinaNN_training.py > ' +'./'+name_experiment+'/'+name_experiment+'_training.nohup') #运行retina_training.py文件
else:
    print ("\n2. Run the training(no nohup), no GPU")
    os.system(' python ./src/retinaNN_training.py') # 采用挂起的形式运行命令

配置文件:configuration.txt (使用的是section-option方法,可以利用字符串匹配进行参数解析)

[data paths] #数据路径 以及 训练集 测试集的名字
path_local =  ./DRIVE_datasets_training_testing/         #封装好的训练集图像+金标准  和  测试集图像+金标准
train_imgs_original = DRIVE_dataset_imgs_train.hdf5      #封装好的训练集图像
train_groundTruth = DRIVE_dataset_groundTruth_train.hdf5 #封装好的训练集金标准

train_border_masks = DRIVE_dataset_borderMasks_train.hdf5sks #封装好的训练集掩膜test_imgs_original = DRIVE_dataset_imgs_test.hdf5 #封装好的测试集图像test_groundTruth = DRIVE_dataset_groundTruth_test.hdf5 #封装好的测试集金标准test_border_masks = DRIVE_dataset_borderMasks_test.hdf5 #封装好的测试集掩膜[experiment name]name = test[data attributes]# 作者的训练集不多 所以作者采用了分块进行训练的方式,从图像中裁剪的图像块大小为 patch_height*patch_widthpatch_height = 48 patch_width = 48[training settings]#number of total patches:N_subimgs = 190000#if patches are extracted only inside the field of view:inside_FOV = False#Number of training epochsN_epochs = 150batch_size = 32#if running with nohup #不挂断地运行命令nohup = True[testing settings]#Choose the model to test: best==epoch with min loss, last==last epochbest_last = best#number of full images for the test (max 20)full_images_to_test = 20#How many original-groundTruth-prediction images are visualized in each imageN_group_visual = 1#Compute average in the prediction, improve results but require more patches to be predictedaverage_mode = True#Only if average_mode==True. Stride for patch extraction, lower value require more patches to be predictedstride_height = 5stride_width = 5#if running with nohupnohup = False