欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Siamese代码解析之训练及测试

程序员文章站 2024-03-15 12:32:29
...

记录下代码,分享给大家。主要目的:以后和别人相亲的时候能有底气地多给对方一篇博客作为彩礼。代码比较短,但还是花了一段时间理解,还没测试,不知道有没有bug。

参考链接:PyTorch 实现孪生网络识别面部相似度-PyTorch 中文网  https://www.pytorchtutorial.com/pytorch-one-shot-learning/#i-4

 

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 24 10:00:24 2018
Paper: Siamese Neural Networks for One-shot Image Recognition
links: https://www.cnblogs.com/denny402/p/7520063.html
"""
import torch
from torch.autograd import Variable
import os
import random
import linecache
import numpy as np
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import PIL.ImageOps
import matplotlib.pyplot as plt


# Paths to the AT&T/ORL faces dataset and the generated index files.
# NOTE(review): hard-coded Windows paths — adjust for the local machine.
root = r'E:/siamese/faces'
train_txt_root = r'E:/siamese/faces/train.txt'
test_txt_root = r'E:/siamese/faces/test.txt'

# Training hyper-parameters.
train_batch_size = 32
train_number_epochs = 100


# 图片可视化函数
def imshow(img, text=None, should_save=False):
    npimg = img.numpy()
    plt.axis("off")
    if text:
        plt.text(75, 8, text, style='italic', fontweight='bold',
                 bbox={'facecolor': 'white', 'alpha': 0.8, 'pad': 10})
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


def show_plot(iteration, loss):
    """Plot the recorded loss values against their iteration counters."""
    plt.plot(iteration, loss)
    plt.show()


def convert(train=True):
    """Build the index file listing "<image_path> <class_label>" lines.

    The AT&T face layout is assumed: 40 subjects (folders s1..s40) with
    10 images each.  Images 1-8 of every subject go into train.txt,
    images 9-10 into test.txt.

    train: when True write the training index, otherwise the test index.
    """
    txt_path = train_txt_root if train else test_txt_root
    # images 1..8 for training, 9..10 for testing
    image_ids = range(1, 9) if train else range(9, 11)
    if not os.path.exists(root):
        os.makedirs(root)
    # `with` guarantees the index file is closed even on error
    with open(txt_path, 'w') as f:
        for subject in range(40):
            for img_id in image_ids:
                img_path = root + '/s' + str(subject + 1) + '/' + str(img_id) + '.pgm'
                f.write(img_path + ' ' + str(subject) + '\n')

# ready the dataset, Not use ImageFolder as the author did
class MyDataset(Dataset):

    def __init__(self, txt, transform=None, target_transform=None, should_invert=False):

        self.transform = transform
        self.target_transform = target_transform
        self.should_invert = should_invert
        self.txt = txt

    def __getitem__(self, index):

        line = linecache.getline(self.txt, random.randint(1, self.__len__())) #(1,400)
        line.strip('\n')
        img0_list = line.split()
        should_get_same_class = random.randint(0, 1)
        if should_get_same_class:      #找到相同的类别为止
            while True:
                img1_list = linecache.getline(self.txt, random.randint(1, self.__len__())).strip('\n').split()
                if img0_list[1] == img1_list[1]:
                    break
        else:
            img1_list = linecache.getline(self.txt, random.randint(1, self.__len__())).strip('\n').split()

        img0 = Image.open(img0_list[0])
        img1 = Image.open(img1_list[0])
        img0 = img0.convert("L")    #转灰度图像
        img1 = img1.convert("L")

        if self.should_invert:      #默认不转灰度图像
            img0 = PIL.ImageOps.invert(img0)
            img1 = PIL.ImageOps.invert(img1)

        if self.transform is not None:
            img0 = self.transform(img0)
            img1 = self.transform(img1)

         #两张图片不等为1,相等为0
        return img0, img1, torch.from_numpy(np.array([int(img1_list[1] != img0_list[1])], dtype=np.float32))

    def __len__(self):
        fh = open(self.txt, 'r')
        num = len(fh.readlines())
        fh.close()
        return num


# Visualising some of the data
"""
train_data=MyDataset(txt = Config.txt_root, transform=transforms.ToTensor(), 
                     transform=transforms.Compose([transforms.Scale((100,100)),
                               transforms.ToTensor()], should_invert=False))
train_loader = DataLoader(dataset=train_data, batch_size=8, shuffle=True)
#it = iter(train_loader)
p1, p2, label = it.next()
example_batch = it.next()
concatenated = torch.cat((example_batch[0],example_batch[1]),0)
imshow(torchvision.utils.make_grid(concatenated))
print(example_batch[2].numpy())
"""

# Neural Net Definition, Standard CNNs
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class SiameseNetwork(nn.Module):
    """Siamese CNN producing a 5-d embedding per input image.

    Both inputs of `forward` pass through the same (weight-shared) branch,
    so the distance between the two output embeddings is meaningful.
    """

    def __init__(self):
        super(SiameseNetwork, self).__init__()
        # Three conv stages; ReflectionPad2d(1) + 3x3 conv keeps the
        # 100x100 spatial size unchanged throughout.
        self.cnn1 = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(1, 4, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(4),
            nn.Dropout2d(p=0.2),

            nn.ReflectionPad2d(1),
            nn.Conv2d(4, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),
            nn.Dropout2d(p=0.2),

            nn.ReflectionPad2d(1),
            nn.Conv2d(8, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),
            nn.Dropout2d(p=0.2),
        )

        # MLP head: flattened 8x100x100 feature map -> 5-d embedding
        self.fc1 = nn.Sequential(
            nn.Linear(8 * 100 * 100, 500),
            nn.ReLU(inplace=True),

            nn.Linear(500, 500),
            nn.ReLU(inplace=True),

            nn.Linear(500, 5),
        )

    def forward_once(self, x):
        """Embed one batch of images: (B, 1, 100, 100) -> (B, 5)."""
        features = self.cnn1(x)
        flat = features.view(features.size(0), -1)
        return self.fc1(flat)

    def forward(self, input1, input2):
        """Return the embeddings of both inputs, each of shape (B, 5)."""
        return self.forward_once(input1), self.forward_once(input2)


# Custom Contrastive Loss
class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) +
                                      (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))

        return loss_contrastive

def trian(net):
    """Train `net` for `train_number_epochs` epochs over `train_dataloader`.

    NOTE(review): the name is a typo of "train", kept because the
    __main__ section calls it as `trian(net)`.
    Relies on the module-level `train_dataloader`, `criterion` and
    `optimizer` globals; records the loss every 10 batches and plots it.
    """
    counter = []
    loss_history = []
    iteration_number = 0
    for epoch in range(train_number_epochs):
        for i, data in enumerate(train_dataloader):
            img0, img1, label = data
            optimizer.zero_grad()
            output1, output2 = net(img0, img1)
            # fix: call the module, not criterion.forward(), so that
            # __call__ machinery (hooks etc.) is honoured
            loss_contrastive = criterion(output1, output2, label)
            loss_contrastive.backward()
            optimizer.step()

            if i % 10 == 0:
                print("Epoch:{},  Current loss {}\n".format(epoch, loss_contrastive.item()))
                iteration_number += 10
                counter.append(iteration_number)
                loss_history.append(loss_contrastive.item())
    show_plot(counter, loss_history)

if __name__ == '__main__':
    # --- training setup ---
    # Resize every face image to 100x100 to match fc1's input size.
    pair_transform = transforms.Compose([transforms.Resize((100, 100)),
                                         transforms.ToTensor()])
    train_data = MyDataset(txt=train_txt_root, transform=pair_transform,
                           should_invert=False)
    # batches the training pairs; shuffle each epoch
    train_dataloader = DataLoader(dataset=train_data, shuffle=True,
                                  batch_size=train_batch_size)
    test_data = MyDataset(txt=test_txt_root, transform=pair_transform,
                          should_invert=False)
    test_dataloader = DataLoader(dataset=test_data, shuffle=True, batch_size=1)

    net = SiameseNetwork()
    criterion = ContrastiveLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0005)
    trian(net)

    torch.save(net.state_dict(), 'params.pth')

    # --- testing ---
    # fix: switch to eval mode so Dropout2d is disabled and BatchNorm uses
    # running statistics (the original evaluated in train mode, which
    # distorts the distances), and disable autograd during inference.
    net.eval()
    dataiter = iter(test_dataloader)
    x0, _, _ = next(dataiter)

    with torch.no_grad():
        for i in range(10):
            _, x1, label2 = next(dataiter)
            print('第' + str(i) + '次')

            concatenated = torch.cat((x0, x1), 0)
            output1, output2 = net(x0, x1)
            print(output1)
            print(output2)
            euclidean_distance = F.pairwise_distance(output1, output2)
            print(euclidean_distance)
            imshow(torchvision.utils.make_grid(concatenated),
                   'Dissimilarity: {:.2f}'.format(euclidean_distance.item()))

文档结构 

Siamese代码解析之训练及测试