欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

python gdal + skimage实现基于遥感影像的传统图像分割及合并外加矢量化

程序员文章站 2024-03-24 22:36:10
...

根据我前述博客中对图像传分割算法及图像块合并方法的实验探究,在此将这些方法用于遥感影像并尝试矢量化。
这个过程中我自己遇到了一个棘手的问题,在最后的结果那里有描述,希望知道的朋友帮忙解答一下,谢谢!
直接上代码:

# -*- coding: utf-8 -*-
import os
import cv2
import gdal
from osgeo import ogr,osr
import numpy as np
from skimage import morphology, color
from skimage.segmentation import felzenszwalb, slic, quickshift
from skimage.segmentation import mark_boundaries
from skimage.util import img_as_float
from skimage.future import graph

def read_img(filename):
    dataset=gdal.Open(filename)

    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize

    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0,0,im_width,im_height)

    del dataset 
    return im_width,im_height,im_proj,im_geotrans,im_data

def write_img(filename,im_proj,im_geotrans,im_data):
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1,im_data.shape 

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    if im_bands == 1:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        for i in range(im_bands):
            dataset.GetRasterBand(i+1).WriteArray(im_data[i])

    del dataset

def DoesDriverHandleExtension(drv, ext):
    exts = drv.GetMetadataItem(gdal.DMD_EXTENSIONS)
    return exts is not None and exts.lower().find(ext.lower()) >= 0


def GetExtension(filename):
    ext = os.path.splitext(filename)[1]
    if ext.startswith('.'):
        ext = ext[1:]
    return ext


def GetOutputDriversFor(filename):
    drv_list = []
    ext = GetExtension(filename)
    for i in range(gdal.GetDriverCount()):
        drv = gdal.GetDriver(i)
        if (drv.GetMetadataItem(gdal.DCAP_CREATE) is not None or
            drv.GetMetadataItem(gdal.DCAP_CREATECOPY) is not None) and \
           drv.GetMetadataItem(gdal.DCAP_VECTOR) is not None:
            if ext and DoesDriverHandleExtension(drv, ext):
                drv_list.append(drv.ShortName)
            else:
                prefix = drv.GetMetadataItem(gdal.DMD_CONNECTION_PREFIX)
                if prefix is not None and filename.lower().startswith(prefix.lower()):
                    drv_list.append(drv.ShortName)

    return drv_list

def GetOutputDriverFor(filename):
    drv_list = GetOutputDriversFor(filename)
    ext = GetExtension(filename)
    if not drv_list:
        if not ext:
            return 'ESRI Shapefile'
        else:
            raise Exception("Cannot guess driver for %s" % filename)
    elif len(drv_list) > 1:
        print("Several drivers matching %s extension. Using %s" % (ext if ext else '', drv_list[0]))
    return drv_list[0]

def _weight_mean_color(graph, src, dst, n):
    """Callback to handle merging nodes by recomputing mean color.
    The method expects that the mean color of `dst` is already computed.
    Parameters
    ----------
    graph : RAG
        The graph under consideration.
    src, dst : int
        The vertices in `graph` to be merged.
    n : int
        A neighbor of `src` or `dst` or both.

    Returns
    -------
    data : dict
        A dictionary with the `"weight"` attribute set as the absolute
        difference of the mean color between node `dst` and `n`.
    """
    diff = graph.nodes[dst]['mean color'] - graph.nodes[n]['mean color']
    diff = np.linalg.norm(diff)
    return {'weight': diff}

def merge_mean_color(graph, src, dst):
    """Callback called before merging two nodes of a mean color distance graph.
    This method computes the mean color of `dst`.
    Parameters
    ----------
    graph : RAG
        The graph under consideration.
    src, dst : int
        The vertices in `graph` to be merged.
    """
    graph.nodes[dst]['total color'] += graph.nodes[src]['total color']
    graph.nodes[dst]['pixel count'] += graph.nodes[src]['pixel count']
    graph.nodes[dst]['mean color'] = (graph.nodes[dst]['total color'] /
                                      graph.nodes[dst]['pixel count'])

if __name__ == '__main__':
    img_path = "E:/geo_test/test.tif"
    temp_path = "E:/geo_test/temp/"
    im_width,im_height,im_proj,im_geotrans,im_data = read_img(img_path)  
    temp = im_data.transpose((2,1,0))
    segments_quick = quickshift(temp, kernel_size=3, max_dist=6, ratio=0.5)
    
    mark0 = mark_boundaries(temp, segments_quick)
    save_path = temp_path + "qs_seg0.tif"
    re0 = mark0.transpose((2,1,0))
    write_img(save_path,im_proj,im_geotrans,re0)

    grid_path = temp_path + "qs_grid0.tif"
    grid0 = np.uint8(re0[0,...])
    write_img(grid_path,im_proj,im_geotrans,grid0)

    skeleton = morphology.skeletonize(grid0)
    border0 = np.multiply(grid0, skeleton)
    ret,border0 = cv2.threshold(border0,0,1,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
    border_path = temp_path + "qs_border0.tif"
    write_img(border_path,im_proj,im_geotrans,border0)
    
    g = graph.rag_mean_color(temp, segments_quick)
    labels2 = graph.merge_hierarchical(segments_quick, g, thresh=5, 
		      rag_copy=False,
              in_place_merge=True,
              merge_func=merge_mean_color,
              weight_func=_weight_mean_color)
    label_rgb2 = color.label2rgb(labels2, temp, kind='avg')
    rgb_path = temp_path + "qs_label.tif"
    lb = labels2.transpose((1,0))
    write_img(rgb_path,im_proj,im_geotrans,lb)
    
    mark = mark_boundaries(label_rgb2, labels2)
    save_path = temp_path + "qs_seg.tif"
    re = mark.transpose((2,1,0))
    write_img(save_path,im_proj,im_geotrans,re)

    grid_path = temp_path + "qs_grid.tif"
    grid = np.uint8(re[0,...])
    write_img(grid_path,im_proj,im_geotrans,grid)

    skeleton = morphology.skeletonize(grid)
    border = np.multiply(grid, skeleton)
    ret,border = cv2.threshold(border,0,1,cv2.THRESH_BINARY+cv2.THRESH_OTSU)
    border_path = temp_path + "qs_border.tif"
    write_img(border_path,im_proj,im_geotrans,border)

    # out_shp = temp_path + "temp.shp"
    # RasterToLineshp(border_path, out_shp, 1)

    border_driver = gdal.Open(rgb_path)
    border_band = border_driver.GetRasterBand(1)
    border_mask = border_band.GetMaskBand()

    dst_filename = temp_path + 'temp.shp'
    frmt = GetOutputDriverFor(dst_filename)
    drv = ogr.GetDriverByName(frmt)
    dst_ds = drv.CreateDataSource(dst_filename)
    
    dst_layername = 'out'
    srs = osr.SpatialReference(wkt=border_driver.GetProjection())
    dst_layer = dst_ds.CreateLayer(dst_layername, geom_type=ogr.wkbPolygon, srs=srs)
    # dst_layer = dst_ds.CreateLayer(dst_layername, geom_type=ogr.wkbLineString, srs=srs)


    dst_fieldname = 'DN'
    fd = ogr.FieldDefn(dst_fieldname, ogr.OFTInteger)
    dst_layer.CreateField(fd)
    dst_field = 0

    options = [""]
    options.append('DATASET_FOR_GEOREF=' + rgb_path)
    prog_func = gdal.TermProgress_nocb
    gdal.Polygonize(border_band, border_mask, dst_layer, dst_field, options,
                         callback=prog_func)

    srcband = None
    src_ds = None
    dst_ds = None
    mask_ds = None

# enum WKBGeometryType {
# wkbPoint = 1,
# wkbLineString = 2,
# wkbPolygon = 3,
# wkbTriangle = 17
# wkbMultiPoint = 4,
# wkbMultiLineString = 5,
# wkbMultiPolygon = 6,
# wkbGeometryCollection = 7,
# wkbPolyhedralSurface = 15,
# wkbTIN = 16
# wkbPointZ = 1001,
# wkbLineStringZ = 1002,
# wkbPolygonZ = 1003,
# wkbTrianglez = 1017
# wkbMultiPointZ = 1004,
# wkbMultiLineStringZ = 1005,
# wkbMultiPolygonZ = 1006,
# wkbGeometryCollectionZ = 1007,
# wkbPolyhedralSurfaceZ = 1015,
# wkbTINZ = 1016
# wkbPointM = 2001,
# wkbLineStringM = 2002,
# wkbPolygonM = 2003,
# wkbTriangleM = 2017
# wkbMultiPointM = 2004,
# wkbMultiLineStringM = 2005,
# wkbMultiPolygonM = 2006,
# wkbGeometryCollectionM = 2007,
# wkbPolyhedralSurfaceM = 2015,
# wkbTINM = 2016
# wkbPointZM = 3001,
# wkbLineStringZM = 3002,
# wkbPolygonZM = 3003,
# wkbTriangleZM = 3017
# wkbMultiPointZM = 3004,
# wkbMultiLineStringZM = 3005,
# wkbMultiPolygonZM = 3006,
# wkbGeometryCollectionZM = 3007,
# wkbPolyhedralSurfaceZM = 3015,
# wkbTinZM = 3016,
# }

对应的结果图如下:
原图:
python gdal + skimage实现基于遥感影像的传统图像分割及合并外加矢量化
粗分割结果(代码中的qs_seg0.tif)
python gdal + skimage实现基于遥感影像的传统图像分割及合并外加矢量化
粗分割格网(代码中的qs_grid0.tif)
python gdal + skimage实现基于遥感影像的传统图像分割及合并外加矢量化
粗分割格网骨架(代码中的qs_border0.tif),格网的结果不是单线的,这里取了中心线。
python gdal + skimage实现基于遥感影像的传统图像分割及合并外加矢量化
合并后的分割结果(代码中的qs_seg.tif):
python gdal + skimage实现基于遥感影像的传统图像分割及合并外加矢量化
合并后的格网结果(代码中的qs_grid.tif)
python gdal + skimage实现基于遥感影像的传统图像分割及合并外加矢量化
合并后的格网骨架结果(代码中的qs_border.tif):
python gdal + skimage实现基于遥感影像的传统图像分割及合并外加矢量化
下面是矢量化以后的最终结果,这是代码中的qs_label.tif经过矢量化以后得到的结果,这里说明一下,之所以不用栅格线来直接转矢量线是因为我在GDAL里面并没有找到直接转化的方法,目前的方法强行转的话只能得到双线,完全不对,找了很久也没找到解决办法只能折中一下先得到面了,后面再面转线,看到的朋友如果知道的话烦请告知一下用什么办法可以直接把栅格线转为矢量线,要求脱离arcgis哈。
python gdal + skimage实现基于遥感影像的传统图像分割及合并外加矢量化

TO DO:
1.矢量面转线
2.线简化
3.线平滑
做完更新,感兴趣的朋友可以关注一下。

相关标签: python gdal