欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

pytorch 语义分割-医学图像-脑肿瘤数据集的载入模块

程序员文章站 2022-03-16 19:26:03
...

由于最近目标是完成基于深度学习的脑肿瘤语义分割实验,所以需要用到自定义的数据载入,本文参考了一下博客:https://blog.csdn.net/tuiqdymy/article/details/84779716?utm_source=app,一开始是做的眼底图像分割,由于使用的是DRIVE数据集,所以数据量很少,之前也是按照上面这篇博客标注了关于图片id的txt文件,但是这次是应用在kaggle脑肿瘤数据集上,kaggle脑肿瘤数据集百度云下载连接:链接:https://pan.baidu.com/s/12RTIv-RqEZwYCm27Im2Djw  提取码:tave  数据量挺大,再完全按照上面博客的方法实现对数据的载入显然不现实,所以就自己稍加修改,记录一下自己的学习过程。

首先我们可以看一下数据存储的结构叭:

pytorch 语义分割-医学图像-脑肿瘤数据集的载入模块

 pytorch 语义分割-医学图像-脑肿瘤数据集的载入模块

上面的每一个文件夹代表着每一个病人的脑肿瘤数据,每一个文件夹下的图片数据也大约有20几张数据(包含着原始图片和金标准图片),每个文件夹中的图片数据不一定相等。


下面就来看看数据加载模块的代码叭。

import torch
import os, glob
import random, csv
import matplotlib.pylab as plt
import torchvision
from PIL import Image
from  torch.utils.data import Dataset
import numpy as np


class  driveDateset(Dataset):
    def __init__(self, root, ignore_label=255):
        super(driveDateset,self).__init__()


        self.root = root

        self.files = []

        for file in os.listdir(self.root):
            fil = os.path.join(self.root, file)
            for data1 in os.listdir(fil):
                data1_split = data1.split('.')
                data11_split = data1_split[0].split('_')
                for data2 in os.listdir(fil):
                    data2_split = data2.split('.')
                    data22_split = data2_split[0].split('_')
                    if (data11_split[-1]==data22_split[-2]) & (data22_split[-1]=='mask'):
                        img_file = os.path.join(fil,data1)
                        label_file = os.path.join(fil,data2)
                        self.files.append({
                            "img":img_file,
                            "label":label_file,
                            "name":data1_split[0]
                        })

    #返回数据集大小
    def __len__(self):
        return len(self.files)

    #实现数据的下标索引
    def  __getitem__(self, index):
        dataflies = self.files[index]

        '''load the data '''
        name = dataflies["name"]
        image = Image.open(dataflies["img"]).convert('RGB')
        label = Image.open(dataflies["label"]).convert('L')
        size_origin = image.size # w*h

        I = np.asarray(image, np.float32)

        I = I.transpose((2, 0, 1))  # transpose the  H*W*C to C*H*W
        L = np.asarray(np.array(label), np.int64)
        # print(I.shape,L.shape)
        return I.copy(), L.copy(), np.array(size_origin), name

然后我们可以debug一下,看看代码中的各个参数的具体赋值情况,便于理解再次修改应用到自己所需数据集中。

pytorch 语义分割-医学图像-脑肿瘤数据集的载入模块

我们来看看测试的运行效果图:

pytorch 语义分割-医学图像-脑肿瘤数据集的载入模块

上面是原始图片,下面四张图片是对应的金标准图片。

pytorch 语义分割-医学图像-脑肿瘤数据集的载入模块

pytorch 语义分割-医学图像-脑肿瘤数据集的载入模块

pytorch 语义分割-医学图像-脑肿瘤数据集的载入模块

pytorch 语义分割-医学图像-脑肿瘤数据集的载入模块

到这里就测试完毕。下面是测试部分代码:

if __name__ == '__main__':
    DATA_DIRECTORY = "F:\\experiment_code\\U-net_brain\\kaggle_3m\\train"
    Batch_size = 4
    MEAN = (104.008, 116.669, 122.675)
    dst = driveDateset(DATA_DIRECTORY)
    # just for test,  so the mean is (0,0,0) to show the original images.
    # But when we are training a model, the mean should have another value
    trainloader = torch.utils.data.DataLoader(dst, batch_size=Batch_size)
    plt.ion()
    for i, data in enumerate(trainloader):
        imgs, labels, _, _ = data
        if i % 1 == 0:
            img = torchvision.utils.make_grid(imgs).numpy()
            img = img.astype(np.uint8)  # change the dtype from float32 to uint8,
                    # because the plt.imshow() need the uint8
            img = np.transpose(img, (1, 2, 0))  # transpose the C*H*W to H*W*C
            plt.imshow(img)
            plt.show()
            plt.pause(0.5)

                #label = torchvision.utils.make_grid(labels).numpy()
            labels = labels.numpy().astype(np.uint8)  # change the dtype from float32 to uint8,
            #                                       # because the plt.imshow() need the uint8
            for i in range(labels.shape[0]):
                plt.imshow(labels[i], cmap='gray')
                plt.show()
                plt.pause(0.5)

新手上车深度学习,写的不太完美,但是希望这篇博客对大家有用鸭。

相关标签: 深度学习