做一个基于PyTorch破解captcha验证码的接口
程序员文章站
2022-06-17 14:42:57
一.github下载源码地址:https://github.com/dee1024/pytorch-captcha-recognition根据他的操作步骤走没啥大问题,要生成几万张图片跑几个epoch才可以看到有一点点识别率,我7万张图片跑了30多epoch都没打到他说的准确率,目前只有80%左右二.加载model.pkl并做一个服务接口从源码可以看到训练每次读取64张图片,每100次会保存一个model.pkl文件如果想接着上一次训练完的模型训练,修改文件captcha_train.py,加上这...
一.github下载源码
地址:https://github.com/dee1024/pytorch-captcha-recognition
根据他的操作步骤走没啥大问题,要生成几万张图片跑几个epoch才可以看到有一点点识别率,我7万张图片跑了30多epoch都没打到他说的准确率,目前只有80%左右
二.加载model.pkl并做一个服务接口
从源码可以看到训练每次读取64张图片,每100次会保存一个model.pkl文件
如果想接着上一次训练完的模型训练,修改文件captcha_train.py,加上这两句
当已经训练到你想要的准确率时,就不用训练了,直接加载model.pkl出来用
直接新建一个app.py文件,并把model.pkl,captcha_cnn_model.py, captcha_setting.py放到当前目录
import base64
from flask import request
from flask import Flask
import numpy as np
import torch
from captcha_cnn_model import CNN
from PIL import Image
from torchvision import transforms
import captcha_setting
app=Flask(__name__)
# 加载模型
cnn = CNN()
cnn.eval()
cnn.load_state_dict(torch.load('model.pkl'))
# 定义路由
@app.route("/photo", methods=['POST'])
def get_frame():
# 接收图片
upload_file = request.files['file']
image = Image.open(upload_file)
# 数据归一化,要和训练时的一样,从my_dataset.py文件可以看到
transform = transforms.Compose([
transforms.Grayscale(),
transforms.ToTensor(),
])
# 现在的到的是三维数组(1, 60, 180)
img_tensor = transform(image)
# 因为训练时是按批次训练的,是个四维数组,所以输入的时候也要转换为四维数组,img_tensor.unsqueeze(0)==>(1, 1, 60, 180)
predict_label = cnn(img_tensor.unsqueeze(0))
# 根据索引位置识别出对应哪个字符
c0 = captcha_setting.ALL_CHAR_SET[np.argmax(predict_label[0, 0:captcha_setting.ALL_CHAR_SET_LEN].data.numpy())]
c1 = captcha_setting.ALL_CHAR_SET[np.argmax(predict_label[0, captcha_setting.ALL_CHAR_SET_LEN:2 * captcha_setting.ALL_CHAR_SET_LEN].data.numpy())]
c2 = captcha_setting.ALL_CHAR_SET[np.argmax(predict_label[0, 2 * captcha_setting.ALL_CHAR_SET_LEN:3 * captcha_setting.ALL_CHAR_SET_LEN].data.numpy())]
c3 = captcha_setting.ALL_CHAR_SET[np.argmax(predict_label[0, 3 * captcha_setting.ALL_CHAR_SET_LEN:4 * captcha_setting.ALL_CHAR_SET_LEN].data.numpy())]
predict_label = '%s%s%s%s' % (c0, c1, c2, c3)
return predict_label
if __name__ == "__main__":
app.run(host='0.0.0.0', port=5000)
启动服务后,通过postman访问,没毛病
本文地址:https://blog.csdn.net/qq1607667079/article/details/110923837