【Pytorch】用自己训练的resnet18模型进行推理
程序员文章站
2022-03-27 10:08:04
在网上看了这篇博客:【Pytorch】使用ResNet-50迁移学习进行图像分类训练https://blog.csdn.net/heiheiya/article/details/103028543有的小伙伴就有疑问,模型训练好后,怎么进行推理。于是,我写了这篇关于使用自己训练的resnet18模型进行推理,关于训练的部分请参考上面那篇。# *_* coding : UTF-8 *_*# 开发人员: csu·pan-_-||# 开发时间: 2020/12/29 19:16# 文件名称: re...
在网上看了这篇博客:
【Pytorch】使用ResNet-50迁移学习进行图像分类训练
https://blog.csdn.net/heiheiya/article/details/103028543
有的小伙伴就有疑问,模型训练好后,怎么进行推理。于是,我写了这篇关于使用自己训练的resnet18模型进行推理,训练的部分请参考上面那篇。
# *_* coding : UTF-8 *_*
# 开发人员: csu·pan-_-||
# 开发时间: 2020/12/29 19:16
# 文件名称: resnet_battery_infer.py
# 开发工具: PyCharm
# 功能描述: 用自己训练好的resnet18模型进行推理
import torch
import time
import numpy as np
import os
import cv2
modelPath =r'Battery\battery_resnet18.pt' # 模型路径
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
validPath = r'Battery\valid\ok' # 需要测试的图片路径
txtPath = 'Battery/results' # 结果存储路径
files = os.listdir(validPath) # 展开图片文件列表
model = torch.load(modelPath) # 加载模型
class_name = ['NG','OK'] # 类别名称
classList = [] # 方便计数
with torch.no_grad():
model.eval()
start = time.time()
for j, input in enumerate(files):
onestart = time.time()
img = cv2.imread(os.path.join(validPath,input))
# 转换图片格式
img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x800x800
img = np.ascontiguousarray(img)
# 使用cuda
img = torch.from_numpy(img).to(device)
img = img.float() # int转换为float
img /= 255.0 # 归一化
if img.ndimension() == 3:
img = img.unsqueeze(0) # 1x3x800x800
outputs = model(img)
ret, predictions = torch.max(outputs.data, 1)
print('outputs.data: ',outputs.data)
# outputs.data: tensor([[6.4012e-04, 9.9936e-01]], device='cuda:0')
# tensor转换成list:
inferclass = predictions.cpu().numpy().tolist()[0]
classList.append(class_name[inferclass])
print('class: {:s}'.format(class_name[inferclass]))
print('input: ',input)
# class: 0 或 1
oneend = time.time()
print('one last: {:.4f}'.format(oneend-onestart))
with open(txtPath + '/' + str(j) + '.txt', 'w') as f:
f.write(class_name[inferclass]) # 存储结果为 'NG' 或 'OK'
end = time.time()
print('all last: {:.4f}s in {:d} imgs'.format(end - start,len(files)))
# 统计类别数量:
print('NG num: ',classList.count('NG'))
print('OK num: ', classList.count('OK'))
在我的显卡上:TITAN RTX,还是蛮快的,推理一张图大概 14ms
本文地址:https://blog.csdn.net/qq_36563273/article/details/111962287
上一篇: LINUX安装nginx详细步骤
下一篇: 做绿豆糕绿豆一定要泡吗
推荐阅读
-
python 用opencv调用训练好的模型进行识别的方法
-
计算机视觉(3):用inception-v3模型重新训练自己的数据模型
-
【Pytorch】用自己训练的resnet18模型进行推理
-
PyTorch 迁移学习实践(几分钟即可训练好自己的模型)
-
pytorch加载预训练模型与自己模型不匹配的解决方案
-
PaddleDetection——使用自己制作的VOC数据集进行模型训练的避坑指南
-
python 用opencv调用训练好的模型进行识别的方法
-
计算机视觉(3):用inception-v3模型重新训练自己的数据模型
-
用pytorch搭建简单的语义分割(可训练自己的数据集)
-
PyTorch 迁移学习实践(几分钟即可训练好自己的模型)