欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

【Pytorch】用自己训练的resnet18模型进行推理

程序员文章站 2022-07-08 15:48:59
在网上看了这篇博客:【Pytorch】使用ResNet-50迁移学习进行图像分类训练https://blog.csdn.net/heiheiya/article/details/103028543有的小伙伴就有疑问,模型训练好后,怎么进行推理。于是,我写了这篇关于使用自己训练的resnet18模型进行推理,关于训练的部分请参考上面那篇。# *_* coding : UTF-8 *_*# 开发人员: csu·pan-_-||# 开发时间: 2020/12/29 19:16# 文件名称: re...

在网上看了这篇博客:

【Pytorch】使用ResNet-50迁移学习进行图像分类训练
https://blog.csdn.net/heiheiya/article/details/103028543

有的小伙伴就有疑问,模型训练好后,怎么进行推理。于是,我写了这篇关于使用自己训练的resnet18模型进行推理,训练的部分请参考上面那篇。

# *_* coding : UTF-8 *_*
# 开发人员: csu·pan-_-||
# 开发时间: 2020/12/29 19:16
# 文件名称: resnet_battery_infer.py
# 开发工具: PyCharm
# 功能描述: 用自己训练好的resnet18模型进行推理

import torch
import time
import numpy as np
import os
import cv2

modelPath =r'Battery\battery_resnet18.pt'  # 模型路径
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
validPath = r'Battery\valid\ok' # 需要测试的图片路径
txtPath = 'Battery/results'     # 结果存储路径
files = os.listdir(validPath)   # 展开图片文件列表
model = torch.load(modelPath)   # 加载模型
class_name = ['NG','OK']        # 类别名称
classList = []    # 方便计数

with torch.no_grad():
    model.eval()
    start = time.time()
    for j, input in enumerate(files):
        onestart = time.time()
        img = cv2.imread(os.path.join(validPath,input))
        # 转换图片格式
        img = img[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x800x800
        img = np.ascontiguousarray(img)
        # 使用cuda
        img = torch.from_numpy(img).to(device)
        img = img.float()  # int转换为float
        img /= 255.0  # 归一化
        if img.ndimension() == 3:
            img = img.unsqueeze(0)   # 1x3x800x800
        outputs = model(img)
        ret, predictions = torch.max(outputs.data, 1)
        print('outputs.data: ',outputs.data)
        # outputs.data:  tensor([[6.4012e-04, 9.9936e-01]], device='cuda:0')
        # tensor转换成list:
        inferclass = predictions.cpu().numpy().tolist()[0]
        classList.append(class_name[inferclass])
        print('class: {:s}'.format(class_name[inferclass]))
        print('input: ',input)
        # class: 0 或 1
        oneend = time.time()
        print('one last: {:.4f}'.format(oneend-onestart))
        with open(txtPath + '/' + str(j) + '.txt', 'w') as f:
            f.write(class_name[inferclass])   # 存储结果为 'NG' 或 'OK'

    end = time.time()
    print('all last: {:.4f}s in {:d} imgs'.format(end - start,len(files)))
    # 统计类别数量:
    print('NG num: ',classList.count('NG'))
    print('OK num: ', classList.count('OK'))

【Pytorch】用自己训练的resnet18模型进行推理
在我的显卡上:TITAN RTX,还是蛮快的,推理一张图大概 14ms

本文地址:https://blog.csdn.net/qq_36563273/article/details/111962287