欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

戴口罩的人脸识别

程序员文章站 2022-03-20 21:57:18
...

实现思路:

使用的是武汉大学公开的口罩遮挡人脸数据集,下载地址https://github.com/X-zhangyang/Real-World-Masked-Face-Dataset。其中有一类是真实人脸口罩遮挡(多为明星),如图所示为白敬亭带着真口罩。需要注意的是每个明星的照片很少只有两三个,所以我没有将它split为训练集和验证集。载入数据集的时候,使用ImageFolder,将文件夹的名字作为标注。选择合适的分类网络进行训练即可。

戴口罩的人脸识别

踩过的坑:

在使用 ImageFolder读取数据集路径的时候,传入参数root的必须是所有待分类的文件夹的上一层文件夹的名字,也即上图中的"~/AFDB_masked_face_dataset",如果写的是"~/AFDB_masked_face_dataset/baijingting"则会报错。 

戴口罩的人脸识别

 可以参考源码中的解释

class ImageFolder(DatasetFolder):
    """A generic data loader where the images are arranged in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        is_valid_file (callable, optional): A function that takes path of an Image file
            and check if the file is a valid_file (used to check of corrupt files)

     Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader, is_valid_file=None):
        super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS if is_valid_file is None else None,
                                          transform=transform,
                                          target_transform=target_transform,
                                          is_valid_file=is_valid_file)
        self.imgs = self.samples

 ②

在使用不同数据集的时候,要注意把网络的num_classes设置为和数据集所对应的,否则在训练和预测的时候都会报错。训练报错:RuntimeError: cuda runtime error (59) : device-side assert triggered at C:/w/1/s/tmp_conda_3.6_045031/conda/conda-bld/pytorch_1565412750030/work/aten/src\THC/generic/THCTensorMath.cu:26。测试报错    size mismatch for classifier.6.weight: copying a param with shape torch.Size([523, 2048]) from checkpoint, the shape in current model is torch.Size([5, 2048]).
 

实现效果:

戴口罩的人脸识别

 网络认为图片是陈法蓉的概率为99.9%