python gdal + skimage实现基于遥感影像的传统图像分割及合并外加矢量化
程序员文章站
2024-03-24 22:36:10
...
根据我前述博客中对传统图像分割算法及图像块合并方法的实验探究,在此将这些方法用于遥感影像并尝试矢量化。
这个过程中我自己遇到了一个棘手的问题,在最后的结果那里有描述,希望知道的朋友帮忙解答一下,谢谢!
直接上代码:
# -*- coding: utf-8 -*-
import os

import cv2
import numpy as np
from osgeo import ogr, osr
from skimage import morphology, color
from skimage.segmentation import felzenszwalb, slic, quickshift
from skimage.segmentation import mark_boundaries
from skimage.segmentation import find_boundaries
from skimage.util import img_as_float

try:
    # Current GDAL releases only provide the bindings under ``osgeo``.
    from osgeo import gdal
except ImportError:  # very old standalone GDAL bindings
    import gdal
try:
    from skimage.future import graph
except ImportError:  # newer scikit-image moved the RAG tools to skimage.graph
    from skimage import graph
def read_img(filename):
    """Read a whole raster file with GDAL.

    Parameters
    ----------
    filename : str
        Path of any raster format GDAL can open.

    Returns
    -------
    tuple
        ``(im_width, im_height, im_proj, im_geotrans, im_data)``: raster
        size in pixels, projection WKT, 6-element geotransform and the full
        pixel array as returned by ``ReadAsArray`` (``(rows, cols)`` for a
        single band, ``(bands, rows, cols)`` otherwise).

    Raises
    ------
    IOError
        If GDAL cannot open *filename* (``gdal.Open`` returns None instead of
        raising unless ``gdal.UseExceptions()`` is active).
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        raise IOError("GDAL could not open %s" % filename)
    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize
    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)
    # Dropping the last reference closes the dataset.
    del dataset
    return im_width, im_height, im_proj, im_geotrans, im_data
def write_img(filename, im_proj, im_geotrans, im_data):
    """Write a NumPy array to a GeoTIFF with the given georeferencing.

    Parameters
    ----------
    filename : str
        Output path; always written with the GDAL ``GTiff`` driver.
    im_proj : str
        Projection WKT passed to ``SetProjection``.
    im_geotrans : sequence of float
        Six-element affine geotransform passed to ``SetGeoTransform``.
    im_data : numpy.ndarray
        ``(rows, cols)`` for one band or ``(bands, rows, cols)``.

    The GDAL pixel type follows the array dtype: uint8/int8 -> Byte,
    uint16 -> UInt16, int16 -> Int16, uint32 -> UInt32, int32 -> Int32,
    bool -> Byte (stored as 0/1); anything else (floats, int64, ...) ->
    Float32, which represents integers exactly only up to 2**24.

    Raises
    ------
    IOError
        If the GTiff driver cannot create *filename*.
    """
    if im_data.dtype == bool:
        # GDAL has no boolean pixel type; store masks as 0/1 bytes.
        im_data = im_data.astype(np.uint8)

    # Exact dtype names: the former substring test ('int16' in name) sent
    # signed int16 to UInt16, wrapping negative values.
    dtype_name = im_data.dtype.name
    if dtype_name in ('uint8', 'int8'):
        datatype = gdal.GDT_Byte
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int16':
        datatype = gdal.GDT_Int16
    elif dtype_name == 'uint32':
        datatype = gdal.GDT_UInt32
    elif dtype_name == 'int32':
        datatype = gdal.GDT_Int32
    else:
        datatype = gdal.GDT_Float32

    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise IOError("GDAL could not create %s" % filename)
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)

    if im_data.ndim == 2:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        # Iterate planes even when there is only one: WriteArray needs a
        # 2-D array, so a (1, rows, cols) input must not be passed whole.
        for band_number, band_data in enumerate(im_data, start=1):
            dataset.GetRasterBand(band_number).WriteArray(band_data)

    dataset.FlushCache()
    # Dropping the last reference closes the file.
    del dataset
def DoesDriverHandleExtension(drv, ext):
    """Return True if GDAL driver *drv* lists *ext* among its file extensions.

    ``DMD_EXTENSIONS`` holds a whitespace-separated list such as
    ``"shp dbf shz shp.zip"``. The comparison is on whole tokens and
    case-insensitive. The previous substring search also accepted fragments
    (e.g. ``"sh"``), and it accepted an empty extension for every driver.

    Parameters
    ----------
    drv : gdal.Driver
        Driver whose metadata is inspected.
    ext : str
        Extension without the leading dot.
    """
    exts = drv.GetMetadataItem(gdal.DMD_EXTENSIONS)
    if exts is None:
        return False
    return ext.lower() in exts.lower().split()
def GetExtension(filename):
    """Return the last extension of *filename* without its dot.

    The result is ``""`` when there is none; a leading-dot name such as
    ``".hidden"`` counts as having no extension.
    """
    _, dotted_ext = os.path.splitext(filename)
    return dotted_ext[1:] if dotted_ext.startswith('.') else dotted_ext
def GetOutputDriversFor(filename):
    """List short names of writable vector drivers that could produce *filename*.

    A driver qualifies when it has a create capability (``DCAP_CREATE`` or
    ``DCAP_CREATECOPY``) and the ``DCAP_VECTOR`` capability. It is then
    selected if it handles the file's extension, or otherwise if *filename*
    starts with its connection prefix (e.g. ``PG:``).
    """
    ext = GetExtension(filename)
    matches = []
    for index in range(gdal.GetDriverCount()):
        candidate = gdal.GetDriver(index)
        can_write = (candidate.GetMetadataItem(gdal.DCAP_CREATE) is not None
                     or candidate.GetMetadataItem(gdal.DCAP_CREATECOPY) is not None)
        if not can_write or candidate.GetMetadataItem(gdal.DCAP_VECTOR) is None:
            continue
        if ext and DoesDriverHandleExtension(candidate, ext):
            matches.append(candidate.ShortName)
            continue
        prefix = candidate.GetMetadataItem(gdal.DMD_CONNECTION_PREFIX)
        if prefix is not None and filename.lower().startswith(prefix.lower()):
            matches.append(candidate.ShortName)
    return matches
def GetOutputDriverFor(filename):
    """Pick one OGR driver short name for writing *filename*.

    This uses the first match from ``GetOutputDriversFor`` and prints a
    notice when several drivers match. With no match it falls back to
    ``'ESRI Shapefile'`` for extension-less names and raises ``Exception``
    otherwise.
    """
    candidates = GetOutputDriversFor(filename)
    ext = GetExtension(filename)
    if candidates:
        if len(candidates) > 1:
            print("Several drivers matching %s extension. Using %s" % (ext if ext else '', candidates[0]))
        return candidates[0]
    if ext:
        raise Exception("Cannot guess driver for %s" % filename)
    return 'ESRI Shapefile'
def _weight_mean_color(graph, src, dst, n):
    """Edge-weight callback for ``merge_hierarchical``.

    It recomputes the weight between the merged node ``dst`` and one of its
    neighbours ``n``. The callback expects ``merge_mean_color`` to have
    already refreshed ``dst``'s ``'mean color'``.

    Parameters
    ----------
    graph : RAG
        Region adjacency graph being merged.
    src, dst : int
        Nodes being merged (``src`` is unused but required by the callback
        signature).
    n : int
        A neighbour of ``src``, ``dst`` or both.

    Returns
    -------
    dict
        ``{'weight': w}`` where ``w`` is the Euclidean norm of the difference
        between the mean colours of ``dst`` and ``n``.
    """
    merged_mean = graph.nodes[dst]['mean color']
    neighbour_mean = graph.nodes[n]['mean color']
    return {'weight': np.linalg.norm(merged_mean - neighbour_mean)}
def merge_mean_color(graph, src, dst):
    """Merge callback for ``merge_hierarchical``: fold ``src``'s colour stats into ``dst``.

    It adds ``src``'s ``'total color'`` and ``'pixel count'`` to ``dst``
    (in place) and then stores the refreshed ``'mean color'`` on ``dst``.
    ``src`` itself is left untouched.

    Parameters
    ----------
    graph : RAG
        Region adjacency graph being merged.
    src, dst : int
        Node being absorbed and node that survives the merge.
    """
    survivor = graph.nodes[dst]
    absorbed = graph.nodes[src]
    survivor['total color'] += absorbed['total color']
    survivor['pixel count'] += absorbed['pixel count']
    survivor['mean color'] = survivor['total color'] / survivor['pixel count']
def _save_boundary_rasters(image, labels, out_dir, tag, im_proj, im_geotrans):
    """Write the boundary overlay, boundary grid and boundary skeleton rasters.

    Files written (``tag`` is e.g. ``"0"`` or ``""``):

    * ``qs_seg<tag>.tif``: ``image`` with region boundaries painted by
      ``mark_boundaries``;
    * ``qs_grid<tag>.tif``: 0/1 mask of those boundary pixels;
    * ``qs_border<tag>.tif``: skeleton (centre line) of that mask.

    ``image`` and ``labels`` use the transposed (x, y[, band]) layout of the
    main script. Every output is transposed back before writing.
    """
    marked = mark_boundaries(image, labels)
    write_img(out_dir + "qs_seg%s.tif" % tag, im_proj, im_geotrans,
              marked.transpose((2, 1, 0)))

    # Build the grid from the labels themselves. mark_boundaries paints
    # find_boundaries(labels, mode='outer') by default. The old
    # np.uint8(red_channel) shortcut also flagged every pixel whose red
    # channel was already 1.0 in the source image.
    grid = find_boundaries(labels, mode='outer').astype(np.uint8).transpose((1, 0))
    write_img(out_dir + "qs_grid%s.tif" % tag, im_proj, im_geotrans, grid)

    # 'outer' boundaries are two pixels wide; keep only the centre line.
    # (Multiplying by the grid and Otsu-thresholding this 0/1 data, as
    # before, was redundant.)
    border = morphology.skeletonize(grid).astype(np.uint8)
    write_img(out_dir + "qs_border%s.tif" % tag, im_proj, im_geotrans, border)


def _polygonize_labels(raster_path, dst_filename, layer_name='out', field_name='DN'):
    """Polygonize band 1 of *raster_path* into a new vector dataset.

    Each connected group of equal pixel values becomes a polygon whose
    integer ``field_name`` attribute holds that value. An existing
    *dst_filename* is deleted first so the script can be re-run.
    """
    src_ds = gdal.Open(raster_path)
    if src_ds is None:
        raise IOError("GDAL could not open %s" % raster_path)
    src_band = src_ds.GetRasterBand(1)
    mask_band = src_band.GetMaskBand()

    drv = ogr.GetDriverByName(GetOutputDriverFor(dst_filename))
    if os.path.exists(dst_filename):
        drv.DeleteDataSource(dst_filename)
    dst_ds = drv.CreateDataSource(dst_filename)
    if dst_ds is None:
        raise IOError("OGR could not create %s" % dst_filename)

    srs = osr.SpatialReference(wkt=src_ds.GetProjection())
    dst_layer = dst_ds.CreateLayer(layer_name, geom_type=ogr.wkbPolygon, srs=srs)
    dst_layer.CreateField(ogr.FieldDefn(field_name, ogr.OFTInteger))

    options = ['DATASET_FOR_GEOREF=' + raster_path]
    # Field index 0 is the attribute created just above.
    gdal.Polygonize(src_band, mask_band, dst_layer, 0, options,
                    callback=gdal.TermProgress_nocb)

    # Release everything explicitly; closing the OGR dataset flushes it.
    dst_layer = None
    dst_ds = None
    mask_band = None
    src_band = None
    src_ds = None


def main():
    """Over-segment, merge and vectorise the hard-coded test image."""
    img_path = "E:/geo_test/test.tif"
    temp_path = "E:/geo_test/temp/"
    if not os.path.isdir(temp_path):
        os.makedirs(temp_path)

    _, _, im_proj, im_geotrans, im_data = read_img(img_path)
    # GDAL gives (band, row, col); transpose((2, 1, 0)) yields (x, y, band)
    # for skimage. Applying the same transpose again restores the GDAL
    # layout before writing. Assumes a 3-band image (quickshift converts
    # the colours to Lab by default) -- confirm for other inputs.
    temp = im_data.transpose((2, 1, 0))

    # 1. Over-segmentation.
    segments_quick = quickshift(temp, kernel_size=3, max_dist=6, ratio=0.5)
    _save_boundary_rasters(temp, segments_quick, temp_path, "0",
                           im_proj, im_geotrans)

    # 2. Merge adjacent segments whose mean colours differ by less than thresh.
    g = graph.rag_mean_color(temp, segments_quick)
    labels2 = graph.merge_hierarchical(segments_quick, g, thresh=5,
                                       rag_copy=False,
                                       in_place_merge=True,
                                       merge_func=merge_mean_color,
                                       weight_func=_weight_mean_color)
    rgb_path = temp_path + "qs_label.tif"
    write_img(rgb_path, im_proj, im_geotrans, labels2.transpose((1, 0)))

    # bg_label=-1: merged labels start at 0, and newer scikit-image
    # defaults bg_label to 0, which would paint that region with bg_color.
    label_rgb2 = color.label2rgb(labels2, temp, kind='avg', bg_label=-1)
    _save_boundary_rasters(label_rgb2, labels2, temp_path, "",
                           im_proj, im_geotrans)

    # 3. Vectorise the merged label raster: one polygon per region.
    _polygonize_labels(rgb_path, temp_path + 'temp.shp')


if __name__ == '__main__':
    main()
# enum WKBGeometryType {
# wkbPoint = 1,
# wkbLineString = 2,
# wkbPolygon = 3,
# wkbTriangle = 17,
# wkbMultiPoint = 4,
# wkbMultiLineString = 5,
# wkbMultiPolygon = 6,
# wkbGeometryCollection = 7,
# wkbPolyhedralSurface = 15,
# wkbTIN = 16,
# wkbPointZ = 1001,
# wkbLineStringZ = 1002,
# wkbPolygonZ = 1003,
# wkbTriangleZ = 1017,
# wkbMultiPointZ = 1004,
# wkbMultiLineStringZ = 1005,
# wkbMultiPolygonZ = 1006,
# wkbGeometryCollectionZ = 1007,
# wkbPolyhedralSurfaceZ = 1015,
# wkbTINZ = 1016,
# wkbPointM = 2001,
# wkbLineStringM = 2002,
# wkbPolygonM = 2003,
# wkbTriangleM = 2017,
# wkbMultiPointM = 2004,
# wkbMultiLineStringM = 2005,
# wkbMultiPolygonM = 2006,
# wkbGeometryCollectionM = 2007,
# wkbPolyhedralSurfaceM = 2015,
# wkbTINM = 2016,
# wkbPointZM = 3001,
# wkbLineStringZM = 3002,
# wkbPolygonZM = 3003,
# wkbTriangleZM = 3017,
# wkbMultiPointZM = 3004,
# wkbMultiLineStringZM = 3005,
# wkbMultiPolygonZM = 3006,
# wkbGeometryCollectionZM = 3007,
# wkbPolyhedralSurfaceZM = 3015,
# wkbTINZM = 3016,
# }
对应的结果图如下:
原图:
粗分割结果(代码中的qs_seg0.tif)
粗分割格网(代码中的qs_grid0.tif)
粗分割格网骨架(代码中的qs_border0.tif),格网的结果不是单线的,这里取了中心线。
合并后的分割结果(代码中的qs_seg.tif):
合并后的格网结果(代码中的qs_grid.tif)
合并后的格网骨架结果(代码中的qs_border.tif):
下面是矢量化以后的最终结果,这是代码中的qs_label.tif经过矢量化以后得到的结果,这里说明一下,之所以不用栅格线来直接转矢量线是因为我在GDAL里面并没有找到直接转化的方法,目前的方法强行转的话只能得到双线,完全不对,找了很久也没找到解决办法只能折中一下先得到面了,后面再面转线,看到的朋友如果知道的话烦请告知一下用什么办法可以直接把栅格线转为矢量线,要求脱离arcgis哈。
TO DO:
1.矢量面转线
2.线简化
3.线平滑
做完更新,感兴趣的朋友可以关注一下。