欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

做一个基于PyTorch破解captcha验证码的接口

程序员文章站 2022-03-08 07:59:13
一.github下载源码地址:https://github.com/dee1024/pytorch-captcha-recognition根据他的操作步骤走没啥大问题,要生成几万张图片跑几个epoch才可以看到有一点点识别率,我7万张图片跑了30多epoch都没打到他说的准确率,目前只有80%左右二.加载model.pkl并做一个服务接口从源码可以看到训练每次读取64张图片,每100次会保存一个model.pkl文件如果想接着上一次训练完的模型训练,修改文件captcha_train.py,加上这...

一.github下载源码

地址:https://github.com/dee1024/pytorch-captcha-recognition

根据他的操作步骤走没啥大问题,要生成几万张图片跑几个epoch才可以看到有一点点识别率,我7万张图片跑了30多epoch都没打到他说的准确率,目前只有80%左右

二.加载model.pkl并做一个服务接口

从源码可以看到训练每次读取64张图片,每100次会保存一个model.pkl文件
如果想接着上一次训练完的模型训练,修改文件captcha_train.py,加上这两句

做一个基于PyTorch破解captcha验证码的接口

当已经训练到你想要的准确率时,就不用训练了,直接加载model.pkl出来用
直接新建一个app.py文件,并把model.pkl,captcha_cnn_model.py, captcha_setting.py放到当前目录

import base64
from flask import request
from flask import Flask
import numpy as np
import torch
from captcha_cnn_model import CNN
from PIL import Image
from torchvision import transforms
import captcha_setting

app=Flask(__name__)


# 加载模型
cnn = CNN()
cnn.eval()
cnn.load_state_dict(torch.load('model.pkl'))
# 定义路由
@app.route("/photo", methods=['POST'])
def get_frame():
    # 接收图片
    upload_file = request.files['file']
    image = Image.open(upload_file)
    # 数据归一化,要和训练时的一样,从my_dataset.py文件可以看到
    transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.ToTensor(),
    ])
    # 现在的到的是三维数组(1, 60, 180)
    img_tensor = transform(image)
    # 因为训练时是按批次训练的,是个四维数组,所以输入的时候也要转换为四维数组,img_tensor.unsqueeze(0)==>(1, 1, 60, 180)
    predict_label = cnn(img_tensor.unsqueeze(0))
    # 根据索引位置识别出对应哪个字符
    c0 = captcha_setting.ALL_CHAR_SET[np.argmax(predict_label[0, 0:captcha_setting.ALL_CHAR_SET_LEN].data.numpy())]
    c1 = captcha_setting.ALL_CHAR_SET[np.argmax(predict_label[0, captcha_setting.ALL_CHAR_SET_LEN:2 * captcha_setting.ALL_CHAR_SET_LEN].data.numpy())]
    c2 = captcha_setting.ALL_CHAR_SET[np.argmax(predict_label[0, 2 * captcha_setting.ALL_CHAR_SET_LEN:3 * captcha_setting.ALL_CHAR_SET_LEN].data.numpy())]
    c3 = captcha_setting.ALL_CHAR_SET[np.argmax(predict_label[0, 3 * captcha_setting.ALL_CHAR_SET_LEN:4 * captcha_setting.ALL_CHAR_SET_LEN].data.numpy())]
    predict_label = '%s%s%s%s' % (c0, c1, c2, c3)
    return predict_label

if __name__ == "__main__":
    app.run(host='0.0.0.0', port=5000)         

启动服务后,通过postman访问,没毛病

做一个基于PyTorch破解captcha验证码的接口

本文地址:https://blog.csdn.net/qq1607667079/article/details/110923837