欢迎您访问程序员文章站,本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

薄板样条插值(Thin plate splines)算法对TotalText数据集的处理

程序员文章站 2022-04-17 15:54:25
...

最近需要对TotalText数据集进行一些处理,主要分为两步:

1、首先利用opencv将txt里标记的区域批量裁剪并保存下来。

2、然后利用TPS算法对裁剪后的图片进行“拉直”变换,并将输出尺寸定为200×64(宽×高)。效果展示:


薄板样条插值(Thin plate splines)算法对TotalText数据集的处理

                                                                                                                  裁剪后的原图 


薄板样条插值(Thin plate splines)算法对TotalText数据集的处理

                                                                                                          tps处理后“拉直”的图片


注意:

标记点为偶数个:前一半点沿文本上边界从左到右排列,后一半点沿下边界从右到左排列。因此上、下两条边各有 点数/2 个点,需等分为 (点数/2 − 1) 段,即相邻目标点的水平间隔为 宽度/(点数/2 − 1)。

import cv2
import numpy as np
# from PIL import Image
import os

def _parse_gt_line(line):
    """Parse one annotation line into ``(points, orientation_tag)``.

    The fixed slice offsets below assume the TotalText ``poly_gt`` text
    layout, e.g.
    ``x: [[115 503 494 115]], y: [[322 346 426 404]], ornt: [u'm'], transcriptions: [u'...']``
    (TODO confirm against the annotation files in use).

    Returns:
        ``(points, type_)`` where ``points`` is a list of ``[x, y]`` int pairs
        and ``type_`` is the single orientation character, or ``None`` if the
        line does not have the expected shape (blank line, mismatched x/y
        counts, non-integer coordinates).
    """
    # maxsplit=4: a ':' inside the transcription must not add extra fields.
    items = line.split(':', 4)
    if len(items) != 5:
        return None
    xs = items[1][3:-5].split()  # strip " [[" and "]], y"
    ys = items[2][3:-8].split()  # strip " [[" and "]], ornt"
    if len(xs) != len(ys):
        return None
    try:
        points = [[int(x), int(y)] for x, y in zip(xs, ys)]
    except ValueError:
        return None
    type_ = items[3][4:5]  # the character after " [u'"
    return points, type_


def _straightened_target(point_number, w, h):
    """Return target control points that stretch a polygon over a w x h frame.

    The first ``point_number`` points are spaced evenly left-to-right along
    the top edge (y = 0). The next ``point_number`` points are spaced evenly
    right-to-left along the bottom edge (y = h). This assumes the annotation
    lists the top boundary first and then walks back along the bottom.
    Requires ``point_number >= 2``.
    """
    step = w / (point_number - 1)
    top = [[step * i, 0] for i in range(point_number)]
    bottom = [[w - step * i, h] for i in range(point_number)]
    return top + bottom


def tps_for_total(num,
                  gt_dir='/home/this/下载/TotalText/examples/txt/do',
                  img_dir='/home/this/下载/TotalText/examples/jpg/do',
                  out_dir='/home/this/桌面/aaa',
                  out_size=(200, 64),
                  warp_horizontal=True):
    """TPS-"straighten" every annotated word of one TotalText image.

    Reads ``poly_gt_img<num>.txt`` from *gt_dir* and ``img<num>.jpg`` from
    *img_dir*. For each polygon annotation, a thin plate spline maps the
    polygon onto a rectangle spanning the whole image. The warped image is
    resized to *out_size* and written to ``<out_dir>/img<num>_<k>.jpg``,
    where ``k`` counts saved words starting at 1.

    Args:
        num: Image index used to build the annotation/image file names.
        gt_dir: Directory containing ``poly_gt_img*.txt`` annotation files.
        img_dir: Directory containing ``img*.jpg`` images.
        out_dir: Directory the straightened word images are written to.
        out_size: ``(width, height)`` passed to ``cv2.resize``.
        warp_horizontal: When True (default), every annotation is TPS-warped.
            When False, annotations whose orientation tag is ``'h'`` are
            saved as their plain axis-aligned bounding-box crop instead.
            Tags ``'m'`` and ``'c'`` are always warped.

    Returns:
        None. Missing annotation files, unreadable images and malformed
        polygons are reported with ``print`` and skipped.
    """
    print('%d\n' % num)

    gt_path = os.path.join(gt_dir, 'poly_gt_img' + str(num) + '.txt')
    img_path = os.path.join(img_dir, 'img' + str(num) + '.jpg')

    if not os.path.exists(gt_path):
        print('None\n')
        return

    image = cv2.imread(img_path)
    if image is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        print('cannot read image %s\n' % img_path)
        return

    h, w = image.shape[:2]
    tps = cv2.createThinPlateSplineShapeTransformer()
    count = 1  # index of the next saved word image for this num

    # errors='replace': only the ASCII coordinate fields are used, so odd
    # bytes in a transcription must not abort the whole file.
    with open(gt_path, 'r', encoding='utf-8', errors='replace') as gt:
        for line_no, line in enumerate(gt, 1):
            parsed = _parse_gt_line(line)
            if parsed is None:
                if line.strip():
                    print('skip malformed line %d in %s\n' % (line_no, gt_path))
                continue
            points, type_ = parsed

            n = len(points)
            # Need an even count (top half + bottom half) with at least two
            # points per edge; otherwise the target spacing divides by zero
            # or the source/target point counts disagree.
            if n < 4 or n % 2:
                print('skip polygon with %d points (line %d) in %s\n'
                      % (n, line_no, gt_path))
                continue

            # int32: cv2.boundingRect accepts only int32/float32 point arrays.
            x, y, bw, bh = cv2.boundingRect(np.array(points, np.int32))
            word_image = image[y:y + bh, x:x + bw]
            if word_image.shape[0] <= 0 or word_image.shape[1] <= 0:
                continue

            source_shape = np.array(points, np.float32).reshape(1, -1, 2)
            target_shape = np.array(_straightened_target(n // 2, w, h),
                                    np.float32).reshape(1, -1, 2)
            # One match per control point, including point 0.
            matches = [cv2.DMatch(i, i, 0) for i in range(n)]

            if warp_horizontal or type_ != 'h':
                # (target, source) order kept from the original script; it is
                # the order commonly used so that warpImage pulls pixels from
                # the polygon into the rectangle -- verify before changing.
                tps.estimateTransformation(target_shape, source_shape, matches)
                new_image = tps.warpImage(image)
            else:
                # Horizontal text: the bounding-box crop is already straight.
                new_image = word_image

            if new_image is None or new_image.shape[0] <= 0 or new_image.shape[1] <= 0:
                continue
            new_image = cv2.resize(new_image, out_size)

            savepath = os.path.join(out_dir, 'img' + str(num) + '_' + str(count) + '.jpg')
            cv2.imwrite(savepath, new_image)
            count += 1

def main():
    """Run the TPS straightening over image indices 50 to 1599 inclusive."""
    first_index, stop_index = 50, 1600
    image_index = first_index
    while image_index < stop_index:
        tps_for_total(image_index)
        image_index += 1


if __name__ == '__main__':
    main()