欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

记录一个数据集处理的类

程序员文章站 2022-07-12 21:37:24
...

加载数据集和相应标签

熟悉一下预处理的功能

import os
import cv2


class Mydata():

    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir  # 所在的文件夹
        self.label_dir = label_dir  # 标签名称
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path = os.listdir(self.path)  # 返回的是path目录下所以文件名称的列表

    def __getitem__(self, idx):
        img_name = self.img_path[idx]  # 列表加索引对应的就是文件名称
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)  # 三个拼接出来的就是文件路径
        img = cv2.imread(img_item_path)
        label = self.label_dir
        return img, label

    def __len__(self):
        return len(self.img_path)
if __name__ == '__main__':

    roor_dir='dataset/train'
    ants_label_dir='ants'
    ants_dataset=Mydata(roor_dir,ants_label_dir)
    ants_img,ants_label=ants_dataset[0]
    cv2.imshow('1',ants_img)
    cv2.waitKey(0)
    print(ants_dataset.__len__()) # 打印长度