Siamese代码解析之训练及测试
程序员文章站
2024-03-15 12:32:29
...
记录下代码,分享给大家。主要目的:以后和别人相亲的时候能有底气地多给对方一篇博客作为彩礼。代码比较短,但还是花了一段时间理解,还没测试,不知道有没有bug。
参考链接:PyTorch 实现孪生网络识别面部相似度-PyTorch 中文网 https://www.pytorchtutorial.com/pytorch-one-shot-learning/#i-4
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jan 24 10:00:24 2018
Paper: Siamese Neural Networks for One-shot Image Recognition
links: https://www.cnblogs.com/denny402/p/7520063.html
"""
import torch
from torch.autograd import Variable
import os
import random
import linecache
import numpy as np
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import PIL.ImageOps
import matplotlib.pyplot as plt
# Dataset location and training hyper-parameters.
# Layout assumed below: root/s<person>/<index>.pgm (AT&T faces style).
root = r'E:/siamese/faces'
train_txt_root = r'E:/siamese/faces/train.txt'  # index file: "<img_path> <label>" per line
test_txt_root = r'E:/siamese/faces/test.txt'    # same format, held-out images
train_batch_size = 32
train_number_epochs = 100
# 图片可视化函数
def imshow(img, text=None, should_save=False):
    """Display a (C, H, W) image tensor with matplotlib.

    Optionally overlays `text` in a white box near the top of the figure.
    `should_save` is accepted for interface compatibility but unused.
    """
    hwc = np.transpose(img.numpy(), (1, 2, 0))  # CHW -> HWC, as imshow expects
    plt.axis("off")
    if text:
        plt.text(75, 8, text, style='italic', fontweight='bold',
                 bbox={'facecolor': 'white', 'alpha': 0.8, 'pad': 10})
    plt.imshow(hwc)
    plt.show()
def show_plot(iteration, loss):
    """Plot the loss curve (loss values against iteration counts) and show it."""
    plt.plot(iteration, loss)
    plt.show()
def convert(train=True):
    """Generate the train or test index file listing image paths and labels.

    Assumes the AT&T faces layout ``root/s<person>/<index>.pgm`` with 40
    subjects: images 1-8 per subject go to the training index, images 9-10
    to the test index. Each written line is ``<image_path> <label>``.

    Args:
        train: when True write ``train_txt_root``, otherwise ``test_txt_root``.
    """
    # The two original branches differed only in target file and image range.
    if train:
        txt_path, img_indices = train_txt_root, range(1, 9)   # images 1..8
    else:
        txt_path, img_indices = test_txt_root, range(9, 11)   # images 9..10
    data_path = root
    # Create the root directory if missing (the .pgm files themselves must
    # already exist there for later loading to succeed).
    if not os.path.exists(data_path):
        os.makedirs(data_path)
    # 'with' closes the file even if a write raises (original leaked handles).
    with open(txt_path, 'w') as f:
        for label in range(40):  # 40 subjects: s1 .. s40
            for j in img_indices:
                img_path = data_path + '/s' + str(label + 1) + '/' + str(j) + '.pgm'
                f.write(img_path + ' ' + str(label) + '\n')
# ready the dataset, Not use ImageFolder as the author did
class MyDataset(Dataset):
    """Siamese pair dataset driven by an index file of ``<path> <label>`` lines.

    ``__getitem__`` ignores its ``index`` argument and draws a *random* pair
    of images each call; with probability 0.5 the second image is resampled
    until it shares the first image's class (a guaranteed positive pair).
    The returned target is 1.0 for a mismatched-class pair, 0.0 for a match.
    """

    def __init__(self, txt, transform=None, target_transform=None, should_invert=False):
        self.transform = transform
        self.target_transform = target_transform
        self.should_invert = should_invert
        self.txt = txt  # path to the index file

    def __getitem__(self, index):
        # NOTE: `index` is deliberately unused -- pairs are sampled at random.
        line = linecache.getline(self.txt, random.randint(1, self.__len__()))
        # fix: the original called line.strip('\n') and discarded the result.
        img0_list = line.strip('\n').split()
        should_get_same_class = random.randint(0, 1)
        if should_get_same_class:
            # Resample until the second image has the same class label.
            while True:
                img1_list = linecache.getline(self.txt, random.randint(1, self.__len__())).strip('\n').split()
                if img0_list[1] == img1_list[1]:
                    break
        else:
            # Single random draw; it may still land on the same class, in
            # which case the target computed below is correctly 0.
            img1_list = linecache.getline(self.txt, random.randint(1, self.__len__())).strip('\n').split()
        img0 = Image.open(img0_list[0]).convert("L")  # to grayscale
        img1 = Image.open(img1_list[0]).convert("L")
        if self.should_invert:  # optional pixel inversion, off by default
            img0 = PIL.ImageOps.invert(img0)
            img1 = PIL.ImageOps.invert(img1)
        if self.transform is not None:
            img0 = self.transform(img0)
            img1 = self.transform(img1)
        # Target: 1.0 when the two class labels differ, 0.0 when they match.
        return img0, img1, torch.from_numpy(np.array([int(img1_list[1] != img0_list[1])], dtype=np.float32))

    def __len__(self):
        # fix: use 'with' so the handle is closed even on error, and count
        # lines without materializing the whole file in memory.
        with open(self.txt, 'r') as fh:
            return sum(1 for _ in fh)
# Visualising some of the data
# NOTE(review): dead demo snippet kept as a module-level string. It does not
# run as written (references an undefined `it`, a non-existent `Config`, and
# the long-deprecated `transforms.Scale`) -- left untouched intentionally.
"""
train_data=MyDataset(txt = Config.txt_root, transform=transforms.ToTensor(),
transform=transforms.Compose([transforms.Scale((100,100)),
transforms.ToTensor()], should_invert=False))
train_loader = DataLoader(dataset=train_data, batch_size=8, shuffle=True)
#it = iter(train_loader)
p1, p2, label = it.next()
example_batch = it.next()
concatenated = torch.cat((example_batch[0],example_batch[1]),0)
imshow(torchvision.utils.make_grid(concatenated))
print(example_batch[2].numpy())
"""
# Neural Net Definition, Standard CNNs
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class SiameseNetwork(nn.Module):
    """Twin CNN with shared weights.

    Both inputs pass through the same conv tower + MLP head; each 1x100x100
    grayscale image is mapped to a 5-dimensional embedding.
    """

    def __init__(self):
        super(SiameseNetwork, self).__init__()
        # Build the three identical conv stages in a loop; extending a flat
        # list keeps the nn.Sequential layer indices (state_dict keys)
        # identical to a hand-written Sequential.
        stages = []
        for in_ch, out_ch in ((1, 4), (4, 8), (8, 8)):
            stages += [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_ch, out_ch, kernel_size=3),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(out_ch),
                nn.Dropout2d(p=.2),
            ]
        self.cnn1 = nn.Sequential(*stages)
        # Padded 3x3 convs preserve spatial size, so the flattened feature
        # vector for a 100x100 input is 8 * 100 * 100.
        self.fc1 = nn.Sequential(
            nn.Linear(8 * 100 * 100, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, 500),
            nn.ReLU(inplace=True),
            nn.Linear(500, 5),
        )

    def forward_once(self, x):
        """Embed a single batch of images into 5-D vectors."""
        features = self.cnn1(x)
        flattened = features.view(features.size(0), -1)
        return self.fc1(flattened)

    def forward(self, input1, input2):
        """Run both inputs through the shared tower; returns (emb1, emb2)."""
        return self.forward_once(input1), self.forward_once(input2)
# Custom Contrastive Loss
class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function (Hadsell, Chopra & LeCun 2006).

    loss = mean( (1 - y) * d^2  +  y * max(margin - d, 0)^2 )
    where y == 1 marks a dissimilar pair and d is the Euclidean distance
    between the two embeddings.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        # fix: keepdim=True keeps the distance shaped (N, 1), matching the
        # (N, 1) label tensor. Without it the distance is (N,), the
        # arithmetic broadcasts to (N, N), and torch.mean silently averages
        # N*N mismatched pairings instead of N per-sample losses.
        euclidean_distance = F.pairwise_distance(output1, output2, keepdim=True)
        loss_contrastive = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) +
                                      label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        return loss_contrastive
def trian(net):  # NOTE(review): typo name kept -- __main__ calls it as `trian`
    """Train `net` for `train_number_epochs` epochs.

    Relies on the module-level globals `train_dataloader`, `criterion`,
    and `optimizer`. Logs the loss every 10 batches and plots the collected
    loss curve once training finishes.
    """
    counter = []
    loss_history = []
    iteration_number = 0
    for epoch in range(0, train_number_epochs):
        for i, data in enumerate(train_dataloader, 0):
            img0, img1, label = data
            output1, output2 = net(img0, img1)
            optimizer.zero_grad()  # clear stale gradients before backward
            # fix: call the module itself rather than criterion.forward(...)
            # so __call__ runs (hooks, autograd bookkeeping are respected).
            loss_contrastive = criterion(output1, output2, label)
            loss_contrastive.backward()
            optimizer.step()
            if i % 10 == 0:
                print("Epoch:{}, Current loss {}\n".format(epoch, loss_contrastive.item()))
                iteration_number += 10
                counter.append(iteration_number)
                loss_history.append(loss_contrastive.item())
    show_plot(counter, loss_history)
if __name__=='__main__':
    # Shared preprocessing: resize every image to the 100x100 input the
    # network's fully-connected head (8*100*100) expects.
    transform = transforms.Compose([transforms.Resize((100, 100)),
                                    transforms.ToTensor()])
    train_data = MyDataset(txt=train_txt_root, transform=transform, should_invert=False)
    train_dataloader = DataLoader(dataset=train_data, shuffle=True,
                                  batch_size=train_batch_size)
    test_data = MyDataset(txt=test_txt_root, transform=transform, should_invert=False)
    test_dataloader = DataLoader(dataset=test_data, shuffle=True, batch_size=1)

    net = SiameseNetwork()
    criterion = ContrastiveLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0005)
    trian(net)
    torch.save(net.state_dict(), 'params.pth')

    ## test: compare one fixed anchor image against 10 other test images
    # fix: switch to eval mode so Dropout/BatchNorm behave deterministically
    # at inference, and disable gradient tracking (no training here).
    net.eval()
    dataiter = iter(test_dataloader)
    x0, _, _ = next(dataiter)
    with torch.no_grad():
        for i in range(10):
            _, x1, label2 = next(dataiter)
            print('第' + str(i) + '次')
            concatenated = torch.cat((x0, x1), 0)
            output1, output2 = net(x0, x1)
            print(output1)
            print(output2)
            # Small distance => same person; large distance => different.
            euclidean_distance = F.pairwise_distance(output1, output2)
            print(euclidean_distance)
            imshow(torchvision.utils.make_grid(concatenated),
                   'Dissimilarity: {:.2f}'.format(euclidean_distance.item()))
文档结构
上一篇: Java学习58:编写泛型
下一篇: [蓝桥杯]猜年龄(Python 实现)