欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

骨架矢量化sknw源码研读

程序员文章站 2022-03-22 22:59:58
...

路网分割后得到区域(region),对其提取骨架得到中心线(centerline);之后需要对骨架进行矢量化,得到结点和边,进而转换成 GeoJSON 格式用于生产。

本文研读矢量化函数库 sknw 的源码,并对其进行改进,使结点与边紧密相连。

一、骨架提取并矢量化demo

from skimage.morphology import skeletonize
from skimage import data
import sknw
import numpy as np
import matplotlib.pyplot as plt

# Skeletonize the inverse (~img) of scikit-image's sample horse image.
img = data.horse()
ske = skeletonize(~img).astype(np.uint16)

# Vectorize the skeleton into a networkx graph of node regions and edges.
graph = sknw.build_sknw(ske)

# Show the image with every edge polyline overlaid in green
# (pts holds (row, col) pairs, so x = col and y = row).
plt.imshow(img, cmap='gray')
for start, end in graph.edges():
    line = graph[start][end]['pts']
    rows, cols = line[:, 0], line[:, 1]
    plt.plot(cols, rows, 'green')

# Optional: mark each node centre ('o') with a red dot.
# centres = np.array([graph.nodes[i]['o'] for i in graph.nodes()])
# plt.plot(centres[:, 1], centres[:, 0], 'r.')

plt.title('Build Graph')
plt.show()
# plt.savefig('pc.png')

二、sknw源码研读:

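①在图像四周补一圈宽度为 1 像素的 0(buffer),以免原图边界上的像素在取 3×3 邻域时越界。

②对二值化骨架图中的像素按类型进行映射:背景为 0,边像素(恰有 2 个骨架邻居)为 1,结点像素(其余骨架像素,如端点和交叉点)为 2。

③提取结点域:相互连通的结点像素组成一个结点域,对每个结点域依次编号,从 10 开始递增。从 10 开始是为了避免与第②步的 0/1/2 标记混淆,因此 10 代表第 0 个结点。

④以结点像素为入口,在其邻域中寻找边像素并沿边逐像素追踪;追踪过程中第二次遇到结点时停止,已走过的像素即作为这两个结点之间的连线。

⑤计算每个结点域的中心点(域内所有像素坐标的均值)。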

import numpy as np
from numba import jit
import networkx as nx
import matplotlib.pyplot as plt

def neighbors(shape):
    """Return flat-index offsets of the full 3**ndim neighbourhood.

    ``shape`` is the array shape as a tuple.  The result holds one offset
    per neighbour (the centre itself excluded), in row-major order, so that
    ``flat[p + off]`` addresses a neighbour of flattened pixel ``p``.
    """
    ndim = len(shape)
    # Every position of a 3**ndim cube except its centre.
    cube = np.ones([3] * ndim)
    cube[(1,) * ndim] = 0
    rel = np.argwhere(cube > 0).astype(np.uint8) - np.array([1] * ndim)
    # Row-major strides, e.g. (W, 1) for a 2-D (H, W) array.
    strides = np.cumprod((1,) + shape[::-1][:-1])[::-1]
    return rel.dot(strides)

@jit  # numba-compiled per-pixel loop
def mark(img):
    """Classify skeleton pixels in place: 0 background, 1 edge, 2 node.

    A nonzero pixel with exactly two nonzero neighbours (full 3**ndim
    neighbourhood from ``neighbors``) becomes an edge pixel (1).  Any other
    nonzero pixel (end point, junction, isolated pixel) becomes a node
    pixel (2).

    ``img`` needs a zero border at least one pixel wide (see ``buffer``).
    Neighbour offsets are applied without bounds checks, so a nonzero
    border pixel could index outside the array or wrap to another row.
    ``img`` should also be C-contiguous: labels are written through
    ``img.ravel()``, which is only a view (not a copy) in that case.
    """
    nbs = neighbors(img.shape)
    flat = img.ravel()
    for p in range(len(flat)):
        if flat[p] == 0:
            continue
        count = 0
        for dp in nbs:
            if flat[p + dp] != 0:
                count += 1
        # Relabelling keeps p nonzero, so later neighbour counts are unaffected.
        flat[p] = 1 if count == 2 else 2



@jit  # trans index to r, c...
def idx2rc(idx, acc):
    """Convert flat indices of the padded image into unpadded coordinates.

    idx -- 1-D array of flat indices into the padded, raveled image.
    acc -- row-major strides of the padded image, e.g. (W, 1) for 2-D.

    Returns a (len(idx), len(acc)) integer array with one coordinate row
    per index.  The final ``- 1`` removes the one-pixel border added by
    ``buffer``, so coordinates refer to the original image.

    ``idx`` is left unmodified.  The callers pass a slice of a shared
    scratch buffer, and mutating it was an unintended side effect.
    """
    # int32 rather than int16: int16 overflows once a padded dimension
    # exceeds 32767 pixels, which large rasters can reach.
    rst = np.zeros((len(idx), len(acc)), dtype=np.int32)
    for i in range(len(idx)):
        rem = idx[i]  # local copy; do not write back into the caller's array
        for j in range(len(acc)):
            rst[i, j] = rem // acc[j]
            rem -= rst[i, j] * acc[j]
    rst -= 1
    return rst
    
@jit  # flood-fill one node region (it may span several pixels)
def fill(img, p, num, nbs, acc, buf):
    """Breadth-first flood fill of the region containing flat index ``p``.

    Every pixel reachable from ``p`` through the offsets in ``nbs`` that
    holds the same value as ``img[p]`` is relabelled in place to ``num``.
    As called from ``parse_struc``, that value is 2, i.e. node pixels.

    ``buf`` is scratch space used as the FIFO queue.  It must be able to
    hold every pixel of the region.

    Returns the region's coordinates from ``idx2rc``: one coordinate row
    per pixel, in visit order, with the one-pixel padding removed.
    """
    target = img[p]
    img[p] = num
    buf[0] = p
    head = 0
    tail = 1
    while head < tail:
        here = buf[head]
        head += 1
        for dp in nbs:
            there = here + dp
            if img[there] == target:
                img[there] = num
                buf[tail] = there
                tail += 1
    return idx2rc(buf[:tail], acc)

@jit # edge pixels are collected in the preallocated buf (a Python list did not work under numba)
def trace(img, p, nbs, acc, buf):
    """Follow one edge from edge pixel ``p`` until a second node label is met.

    Every visited pixel is appended to ``buf`` and zeroed in ``img`` so it is
    never visited again.  At each step all neighbours of the current pixel
    are scanned: a value >= 10 is a node-region label (assigned by ``fill``),
    and a value of 1 is an edge pixel to step to next.

    c1   -- first node label seen (the head node); 0 while none seen yet.
    c2   -- a node label seen after c1 is set (the tail node).  It can be
            found in the same neighbour scan as c1 and can equal c1 (e.g. a
            path back into the same region); if several are seen in one
            scan the last one wins.  A nonzero c2 ends the walk.
    newp -- the next edge pixel to visit.

    Pixels are stored in walk order, from the end next to c1 to the end
    next to c2.

    Original author's note: the (c1, c2) head/tail order must be kept in
    step with that pixel order.  Otherwise lines drawn from it later jump
    between the wrong ends ("flying lines").

    Author's planned modification (translated): add the head and tail node
    pixels to the edge, so that edges also connect the pixels inside node
    regions and the topology closes.  The code for this is commented out
    below and is NOT active.

    NOTE(review): ``buf`` must be large enough for the whole edge.  If a step
    finds no neighbour equal to 1 while c2 is still 0, ``p`` falls back to a
    stale ``newp`` (already zeroed) and the walk may not terminate.  This
    presumably relies on ``mark`` making every edge chain end at a node.

    Returns ``(c1 - 10, c2 - 10, coords)``: 0-based node indices and the
    edge pixels converted to coordinates by ``idx2rc``.
    """
    c1 = 0; c2 = 0;
    newp = 0
    cur = 0
    while True:
        buf[cur] = p
        img[p] = 0
        cur += 1
        for dp in nbs:
            cp = p + dp
            if img[cp] >= 10:
                if c1==0:
                    c1=img[cp]
                    # (inactive) would remember the head node pixel:
                    # c1_p = cp
                else:
                    c2 = img[cp]
                    # (inactive) would remember the tail node pixel:
                    # c2_p = cp
            if img[cp] == 1:
                newp = cp
        p = newp
        if c2!=0:break
    # (inactive) author's modification: place the head/tail node pixels at
    # the two ends of the edge.
    # buf = np.insert(buf,0,c1_p)
    # buf[cur+1] = c2_p
    # cur += 2
    return (c1-10, c2-10, idx2rc(buf[:cur], acc))
   
@jit  # parse the image then get the nodes and edges
def parse_struc(img):
    """Extract node regions and edges from a marked, padded skeleton.

    ``img`` is the output of ``mark`` (0 background, 1 edge, 2 node) with a
    one-pixel zero border.  It is relabelled in place: node regions become
    10, 11, ..., and traced edge pixels become 0.

    Returns ``(nodes, edges)``:
      nodes -- list of coordinate arrays, one per node region; entry i is
               the region labelled 10 + i.
      edges -- list of ``(start_node, end_node, coords)`` from ``trace``.

    Edges are only traced starting from node pixels, so closed loops that
    contain no node pixel never appear in ``edges``.
    """
    nbs = neighbors(img.shape)
    # Row-major strides of the padded image, e.g. (W, 1) for 2-D.
    acc = np.cumprod((1,) + img.shape[::-1][:-1])[::-1]
    img = img.ravel()
    # Flat indices of all node pixels, in raster order.
    pts = np.array(np.where(img == 2))[0]
    # Scratch queue shared by fill() and trace().  Each call stores every
    # pixel of one node region or one edge at most once, so the number of
    # skeleton pixels is a safe upper bound.  A fixed size (e.g. 131072)
    # can be overrun by a long edge or a large node region in big images.
    buf = np.zeros(np.count_nonzero(img), dtype=np.int64)
    # Labels start at 10 so they never collide with the 0/1/2 marks.
    # NOTE(review): labels are stored in img's dtype; with the uint16 image
    # built by buffer(), more than 65525 regions would overflow.
    num = 10
    nodes = []
    for p in pts:
        if img[p] == 2:  # not yet absorbed into an earlier region
            nds = fill(img, p, num, nbs, acc, buf)
            num += 1
            nodes.append(nds)

    edges = []
    for p in pts:
        for dp in nbs:
            # A remaining edge pixel next to a node starts an untraced edge.
            # trace() zeroes what it walks, so each edge is found only once.
            if img[p + dp] == 1:
                edge = trace(img, p + dp, nbs, acc, buf)
                edges.append(edge)

    return nodes, edges
    
def build_graph(nodes, edges, multi=False):
    """Assemble a networkx graph from ``parse_struc`` output.

    Node ``i`` gets ``pts`` (its region's pixel coordinates) and ``o`` (the
    mean of those coordinates, i.e. the region centre).

    Each ``(s, e, pts)`` edge gets ``pts`` and ``weight``.  The weight is the
    polyline length: the sum of Euclidean distances between consecutive
    points.

    ``multi`` selects ``nx.MultiGraph`` instead of ``nx.Graph``.
    """
    graph = nx.MultiGraph() if multi else nx.Graph()
    for idx, region in enumerate(nodes):
        graph.add_node(idx, pts=region, o=region.mean(axis=0))
    for start, end, line in edges:
        seg_lengths = np.linalg.norm(np.diff(line, axis=0), axis=1)
        graph.add_edge(start, end, pts=line, weight=seg_lengths.sum())
    return graph

def buffer(ske):
    """Return ``ske`` as uint16, surrounded by a one-pixel border of zeros.

    The border lets the offsets from ``neighbors`` be applied to every
    original pixel without stepping outside the array.
    """
    padded_shape = tuple(n + 2 for n in ske.shape)
    padded = np.zeros(padded_shape, dtype=np.uint16)
    interior = (slice(1, -1),) * padded.ndim
    padded[interior] = ske
    return padded

def build_sknw(ske, multi=False):
    """Vectorize a skeleton image into a networkx graph.

    ``ske`` is copied into a zero-padded array by ``buffer``.  That copy
    (never ``ske`` itself) is labelled in place by ``mark`` and then
    ``parse_struc``.  The extracted nodes and edges are assembled by
    ``build_graph``.
    """
    padded = buffer(ske)
    mark(padded)  # writes the 0/1/2 labels into `padded`
    return build_graph(*parse_struc(padded), multi)
    
def draw_graph(img, graph, cn=255, ce=128):
    """Rasterize a graph from ``build_sknw`` onto ``img`` in place.

    Node pixels (node attribute ``pts``) are set to ``cn``.  Edge pixels
    (edge attribute ``pts``) are set to ``ce``.  Edges are drawn after
    nodes, so they win where the two overlap.

    Works for both ``nx.Graph`` and ``nx.MultiGraph``.  Coordinates are
    written with one index array per axis, so the result lands in ``img``
    whatever its memory layout (writing through ``img.ravel()`` would draw
    into a copy for non-contiguous arrays).
    """
    def paint(pts, value):
        # Round to integer pixels: pts may be float, e.g. after join_nodes
        # prepends a node centre.
        coords = np.rint(np.asarray(pts)).astype(np.intp)
        img[tuple(coords.T)] = value

    for n in graph.nodes:
        paint(graph.nodes[n]['pts'], cn)
    # edges(data=True) yields every edge once, including each parallel
    # edge of a MultiGraph, with its attribute dict.
    for *_, attrs in graph.edges(data=True):
        paint(attrs['pts'], ce)

if __name__ == '__main__':
    # Scratch check of networkx.MultiGraph: print its nodes, its edges and a
    # one-node subgraph.  Does not call the skeleton pipeline above.
    demo = nx.MultiGraph()
    demo.add_nodes_from(range(1, 6))
    demo.add_edges_from([(1, 2), (1, 3), (2, 3), (4, 5), (5, 4)])
    print(demo.nodes())
    print(demo.edges())
    single = demo.subgraph(1)
    for item in ('d', single, 'd'):
        print(item)
    

最终输出每个结点域的中心点和各条边,但结果中存在结点与边分离的情况,如下图所示:

骨架矢量化sknw源码研读

三、添加代码:结点与边相连

分析不相连的原因:每个结点域只用其中心点表示,域内其余结点像素不会显示,而边的端点止于结点域边缘的相邻像素,因此出现断线。我的思路是从线的两个端点出发,把位于端点与结点中心之间的结点像素并入线中作为扩充点(若结点域只有一个像素,则直接接上该中心点)。注意线的方向:线上的点是有序存储的,头部补点要插在最前面,尾部补点要接在最后面,两端补点的顺序正好相反。添加代码如下:

#add
def _chebyshev_gap(point, region):
    """Smallest Chebyshev (8-neighbour) distance from ``point`` to any pixel of ``region``."""
    diff = np.asarray(region, dtype=float) - np.asarray(point, dtype=float)
    return np.abs(diff).max(axis=1).min()


def _node_bridge(center, region, end_point):
    """Return the rows to splice between an edge end point and its node.

    A single-pixel node contributes its centre.  A larger node contributes
    its pixels (in region order) that lie inside the axis-aligned box
    spanned by the centre and ``end_point``, bounds inclusive.
    """
    if len(region) == 1:
        return [center]
    lo = np.minimum(center, end_point)
    hi = np.maximum(center, end_point)
    return [row for row in region if np.all((row >= lo) & (row <= hi))]


def join_nodes(graph):
    """Extend each edge polyline into its two end nodes so they visibly connect.

    ``build_graph`` stores a node region as ``pts`` (all its pixels) and
    ``o`` (their mean).  An edge's ``pts`` stop at the edge pixel next to
    each region, so drawn edges end short of the node centres.

    For each end of every edge, this function adds:
      * the node centre ``o`` if the node is a single pixel;
      * otherwise the node pixels inside the box spanned by the centre and
        the edge end point.  They are prepended in region order at the head
        and appended in reverse region order at the tail.

    Edges consisting of a single point are left unchanged, so later cleaning
    can still remove them.  Edge ``pts`` are replaced in place and the same
    graph is returned.  Works for ``nx.Graph`` and ``nx.MultiGraph``.
    Node regions of different sizes are handled, because no ragged NumPy
    array is built.
    """
    for s, e, attrs in graph.edges(data=True):
        ps = attrs['pts']
        if len(ps) == 1:
            continue
        head_node, tail_node = graph.nodes[s], graph.nodes[e]
        # graph.edges() may report (s, e) opposite to the direction in which
        # trace() stored the pixels.  Attach ps[0] to the node it actually
        # touches; otherwise the ends are joined crosswise ("flying lines").
        if _chebyshev_gap(ps[0], tail_node['pts']) < _chebyshev_gap(ps[0], head_node['pts']):
            head_node, tail_node = tail_node, head_node
        head = _node_bridge(head_node['o'], head_node['pts'], ps[0])
        tail = _node_bridge(tail_node['o'], tail_node['pts'], ps[-1])[::-1]
        attrs['pts'] = np.vstack(head + [ps] + tail)
    return graph
# Post-process the `graph` built by sknw.build_sknw in the demo script so
# that every edge polyline is extended into its end nodes.
graph = join_nodes(graph)

后处理改进结果:

骨架矢量化sknw源码研读