欢迎您访问程序员文章站!本站旨在为大家提供、分享程序员计算机编程知识!
您现在的位置是: 首页

利用scipy包计算表格线的峰值

程序员文章站 2022-07-12 22:31:10
...
import cv2
import numpy as np
from scipy.signal import find_peaks, peak_widths


def get_lines_from_image(img_bin, axis, kernel_len_div = 20, kernel_len = None, iters = 3):
    """
    Extract lines of one orientation by eroding then dilating with a thin kernel.

    :param img_bin: OpenCV image to filter
    :param axis: 0 selects vertical lines; any other value selects horizontal lines
    :param kernel_len_div: when ``kernel_len`` is None, the kernel length is the
        image side along ``axis`` integer-divided by this value (at least 1)
    :param kernel_len: explicit kernel length; must be > 0 and, when given,
        takes precedence over ``kernel_len_div``
    :param iters: iteration count passed to both the erode and the dilate step
    :return: the image after erosion followed by dilation

    Side effect: the result is always written to "verticle_lines.jpg" or
    "horizontal_lines.jpg" (relative path) because the debug switch is on.
    """
    write_debug_image = True

    # Resolve the kernel length: an explicit value wins over the ratio.
    if kernel_len is None:
        side = np.array(img_bin).shape[axis]
        span = max(side // kernel_len_div, 1)
    else:
        assert kernel_len > 0
        span = kernel_len

    # Kernel size is given as (width, height): 1 x span for vertical lines,
    # span x 1 for horizontal lines.
    if axis == 0:
        ksize, debug_name = (1, span), "verticle_lines.jpg"
    else:
        ksize, debug_name = (span, 1), "horizontal_lines.jpg"

    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, ksize)

    # Erode first, then dilate with the same kernel and iteration count.
    eroded = cv2.erode(img_bin, kernel, iterations=iters)
    lines_img = cv2.dilate(eroded, kernel, iterations=iters)

    if write_debug_image:
        cv2.imwrite(debug_name, lines_img)
    return lines_img

def line_img_add(verticle_lines_img, horizontal_lines_img):
    """
    Combine the detected vertical-line and horizontal-line images.

    Both inputs are blended with equal weights (0.5 each) and a zero offset
    through ``cv2.addWeighted``.

    :param verticle_lines_img: image holding the vertical lines
    :param horizontal_lines_img: image holding the horizontal lines
    :return: the blended image
    """
    vertical_weight = 0.5
    horizontal_weight = 1.0 - vertical_weight
    return cv2.addWeighted(
        verticle_lines_img,
        vertical_weight,
        horizontal_lines_img,
        horizontal_weight,
        0.0,
    )


def project(np_arr, axis):
    """
    Count the zero-valued entries of ``np_arr`` along ``axis``.

    For a 2-D image, axis 0 gives one count per column and axis 1 gives one
    count per row.
    """
    is_zero = np_arr == 0
    return np.count_nonzero(is_zero, axis=axis)

def get_grid_coordinate(img_bin, prominence_ratio = 0.3, height_ratio=None, distance=None, DEBUG=0):
    """
    Compute grid-line x/y coordinates and an estimated line width.

    Pixels equal to 0 are counted per column and per row; peaks of those two
    projections (found with ``scipy.signal.find_peaks``) are taken as the
    positions of the vertical and horizontal lines.

    :param img_bin: 2-D image, white background with black (value 0) lines
    :param prominence_ratio: minimum peak prominence, as a fraction of the
        image height (column projection) or width (row projection)
    :param height_ratio: optional minimum peak height, expressed with the same
        fractions; ``None`` disables the height filter
    :param distance: forwarded unchanged to ``find_peaks`` as ``distance``
    :param DEBUG: when truthy, print the height thresholds and show a
        matplotlib plot of both projections with the detected peaks
    :return: ``(x_peaks, y_peaks, line_width)`` where the first two are lists
        of peak indices and ``line_width`` is an int
    """
    h, w = img_bin.shape

    # Build the zero-pixel mask once; it feeds both projections and the
    # line-width estimate below.
    line_mask = img_bin == 0
    x_prj = np.count_nonzero(line_mask, axis=0)  # one count per column
    y_prj = np.count_nonzero(line_mask, axis=1)  # one count per row

    # Optional absolute height thresholds. A column can hold at most h line
    # pixels and a row at most w, hence the different scale factors.
    height_x = height_y = None
    if height_ratio is not None:
        height_x = height_ratio * h
        height_y = height_ratio * w
    if DEBUG:
        print('height_x,height_y:', height_x, height_y)

    x_peaks, _ = find_peaks(x_prj, height=height_x, distance=distance, prominence=(h * prominence_ratio, None))
    y_peaks, _ = find_peaks(y_prj, height=height_y, distance=distance, prominence=(w * prominence_ratio, None))

    x_peaks = list(x_peaks)
    y_peaks = list(y_peaks)

    if DEBUG:
        # Imported lazily so matplotlib is only needed when plotting is requested.
        import matplotlib.pyplot as plt
        plt.subplot(211)
        plt.title("x")
        print('range(x_prj.shape[0]):',range(x_prj.shape[0]))
        plt.plot(range(x_prj.shape[0]), x_prj)
        plt.plot(x_peaks, x_prj[x_peaks], "x")
        plt.subplot(212)
        plt.title("y")
        plt.plot(range(y_prj.shape[0]), y_prj)
        plt.plot(y_peaks, y_prj[y_peaks], "x")
        plt.show()

    # If no peak was detected, fall back to the image borders.
    if len(x_peaks) == 0:
        x_peaks = [0, w]
        print("x_peaks is None !!!!!!!")
    if len(y_peaks) == 0:
        y_peaks = [0, h]
        print("y_peaks is None !!!!!!!")

    # Line-width estimate, assuming one constant width x for every line.
    # With m horizontal and n vertical lines in an h-by-w table:
    #   zero_pixels = m*w*x + n*h*x - m*n*x^2
    # Dropping the overlap term gives zero_pixels ~= (m*w + n*h) * x.
    m, n = len(y_peaks), len(x_peaks)
    line_width = np.count_nonzero(line_mask) / (m*w + n*h)
    line_width = round(line_width) + 1
    return x_peaks, y_peaks, line_width

if __name__ == '__main__':
    # Demo pipeline: extract table lines from one image, then locate the
    # grid coordinates from the projection peaks of the line image.
    path= './test_page_debug_out_debug/table_crop_fix_rm_char.jpg'
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals an unreadable or missing file by returning None;
        # fail here with a clear message instead of inside cvtColor.
        raise FileNotFoundError('cannot read image: {}'.format(path))
    img_bin = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    verticle_lines_img = get_lines_from_image(img_bin, 0, kernel_len_div=40)
    horizontal_lines_img = get_lines_from_image(img_bin, 1, kernel_len_div=40)
    # Combine the extracted vertical and horizontal table lines.
    img_final_bin_lines = line_img_add(verticle_lines_img, horizontal_lines_img)
    cv2.imwrite('./img_final_bin_lines.jpg',img_final_bin_lines)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    # Invert the line image, erode it with the 3x3 kernel, then binarize.
    img_final_bin_lines = cv2.erode(~img_final_bin_lines, kernel, iterations=2)
    (thresh, img_final_bin_lines) = cv2.threshold(img_final_bin_lines, 128, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)
    cv2.imwrite('./img_final_bin_lines_fix.jpg', img_final_bin_lines)
    # Compute the grid coordinates from the table lines; DEBUG=True requests
    # the projection plot for visual inspection.
    x_grids, y_grids, line_w = get_grid_coordinate(img_final_bin_lines, DEBUG=True)
    

输入:

利用scipy包计算表格线的峰值

提取竖直线:

利用scipy包计算表格线的峰值

提取水平线:

利用scipy包计算表格线的峰值

水平线与竖直线峰值查找:

利用scipy包计算表格线的峰值

相关标签: scipy