pytorch 语义分割-医学图像-脑肿瘤数据集的载入模块
程序员文章站
2022-03-16 19:26:03
...
由于最近目标是完成基于深度学习的脑肿瘤语义分割实验,所以需要用到自定义的数据载入,本文参考了一下博客:https://blog.csdn.net/tuiqdymy/article/details/84779716?utm_source=app,一开始是做的眼底图像分割,由于使用的是DRIVE数据集,所以数据量很少,之前也是按照上面这篇博客标注了关于图片id的txt文件,但是这次是应用在kaggle脑肿瘤数据集上,kaggle脑肿瘤数据集百度云下载连接:链接:https://pan.baidu.com/s/12RTIv-RqEZwYCm27Im2Djw 提取码:tave 数据量挺大,再完全按照上面博客的方法实现对数据的载入显然不现实,所以就自己稍加修改,记录一下自己的学习过程。
首先我们可以看一下数据存储的结构叭:
上面的每一个文件夹代表着每一个病人的脑肿瘤数据,每一个文件夹下的图片数据也大约有20几张数据(包含着原始图片和金标准图片),每个文件夹中的图片数据不一定相等。
下面就来看看数据加载模块的代码叭。
import torch
import os, glob
import random, csv
import matplotlib.pylab as plt
import torchvision
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
class driveDateset(Dataset):
def __init__(self, root, ignore_label=255):
super(driveDateset,self).__init__()
self.root = root
self.files = []
for file in os.listdir(self.root):
fil = os.path.join(self.root, file)
for data1 in os.listdir(fil):
data1_split = data1.split('.')
data11_split = data1_split[0].split('_')
for data2 in os.listdir(fil):
data2_split = data2.split('.')
data22_split = data2_split[0].split('_')
if (data11_split[-1]==data22_split[-2]) & (data22_split[-1]=='mask'):
img_file = os.path.join(fil,data1)
label_file = os.path.join(fil,data2)
self.files.append({
"img":img_file,
"label":label_file,
"name":data1_split[0]
})
#返回数据集大小
def __len__(self):
return len(self.files)
#实现数据的下标索引
def __getitem__(self, index):
dataflies = self.files[index]
'''load the data '''
name = dataflies["name"]
image = Image.open(dataflies["img"]).convert('RGB')
label = Image.open(dataflies["label"]).convert('L')
size_origin = image.size # w*h
I = np.asarray(image, np.float32)
I = I.transpose((2, 0, 1)) # transpose the H*W*C to C*H*W
L = np.asarray(np.array(label), np.int64)
# print(I.shape,L.shape)
return I.copy(), L.copy(), np.array(size_origin), name
然后我们可以debug一下,看看代码中的各个参数的具体赋值情况,便于理解再次修改应用到自己所需数据集中。
我们来看看测试的运行效果图:
上面是原始图片,下面四张图片是对应的金标准图片。
到这里就测试完毕。下面是测试部分代码:
if __name__ == '__main__':
DATA_DIRECTORY = "F:\\experiment_code\\U-net_brain\\kaggle_3m\\train"
Batch_size = 4
MEAN = (104.008, 116.669, 122.675)
dst = driveDateset(DATA_DIRECTORY)
# just for test, so the mean is (0,0,0) to show the original images.
# But when we are training a model, the mean should have another value
trainloader = torch.utils.data.DataLoader(dst, batch_size=Batch_size)
plt.ion()
for i, data in enumerate(trainloader):
imgs, labels, _, _ = data
if i % 1 == 0:
img = torchvision.utils.make_grid(imgs).numpy()
img = img.astype(np.uint8) # change the dtype from float32 to uint8,
# because the plt.imshow() need the uint8
img = np.transpose(img, (1, 2, 0)) # transpose the C*H*W to H*W*C
plt.imshow(img)
plt.show()
plt.pause(0.5)
#label = torchvision.utils.make_grid(labels).numpy()
labels = labels.numpy().astype(np.uint8) # change the dtype from float32 to uint8,
# # because the plt.imshow() need the uint8
for i in range(labels.shape[0]):
plt.imshow(labels[i], cmap='gray')
plt.show()
plt.pause(0.5)
新手上车深度学习,写的不太完美,但是希望这篇博客对大家有用鸭。