记录一个数据集处理的类
程序员文章站
2022-07-12 21:37:24
...
加载数据集和相应标签
熟悉一下预处理的功能
import os
import cv2
class Mydata():
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir # 所在的文件夹
self.label_dir = label_dir # 标签名称
self.path = os.path.join(self.root_dir, self.label_dir)
self.img_path = os.listdir(self.path) # 返回的是path目录下所以文件名称的列表
def __getitem__(self, idx):
img_name = self.img_path[idx] # 列表加索引对应的就是文件名称
img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) # 三个拼接出来的就是文件路径
img = cv2.imread(img_item_path)
label = self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
if __name__ == '__main__':
roor_dir='dataset/train'
ants_label_dir='ants'
ants_dataset=Mydata(roor_dir,ants_label_dir)
ants_img,ants_label=ants_dataset[0]
cv2.imshow('1',ants_img)
cv2.waitKey(0)
print(ants_dataset.__len__()) # 打印长度