欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

利用Python的Scikit-Learn库对遥感影像进行随机森林分类

程序员文章站 2022-07-14 14:51:31
...

利用Python的Scikit-Learn库对遥感影像进行随机森林(RandomForest)分类

随机森林是一个包含多个决策树的分类器,因其运算速度快、分类精度高、算法稳定等特点,被广泛应用到遥感图像的分类研究中。Scikit-Learn作为Python 编程语言的免费软件机器学习库,提供了对随机森林算法的支持,但没有提供针对遥感影像分类的相关函数。因此,本篇文章将为读者介绍利用Python及其扩展包Scikit-Learn对遥感影像进行随机森林分类的完整过程,包括:ShapeFile格式样本数据的读取、栅格数据读取和裁剪、利用Scikit-Learn的RandomForestClassifier模块进行样本训练和遥感影像分类。

一、Scikit-Learn的安装

直接执行命令pip install scikit-learn,所有依赖库都会自动安装。安装完成后,添加代码from sklearn.ensemble import RandomForestClassifier即可使用

二、样本绘制

在ArcGIS中绘制训练样本,格式为shpfile,可以是点类型或线类型,建立Value字段,用于存储分类编号,例如1-湿地,2-湖泊,3-水稻。

三、栅格数据裁剪

通过shp样本获取对应的栅格值,需要使用多边形裁剪栅格数据,我们使用射线算法。全部代码如下(TrainByRandomForest.py):

代码如下(示例):

from sklearn.ensemble import RandomForestClassifier
import osgeo
from osgeo import gdal
from osgeo import osr
from osgeo import ogr
def GetSubRaster(inraster,polygonPoints:list):
    polygonPoints.append(polygonPoints[0])#面多边形坐标封闭
	print("当前多边形节点数量:"+str(len(polygonPoints)))
	#计算最小边界矩形
    minX=10000000000000
    maxX=-minX
    minY=100000000000000000
    maxY=-minY

    for point in polygonPoints:
        if point.X<minX:minX=point.X
        if point.X>maxX:maxX=point.X
        if point.Y<minY:minY=point.Y
        if point.Y>maxY:maxY=point.Y
    leftX=minX
    upY=maxY
    rightX=maxX
    bottomY=minY

    rds = gdal.Open(inraster) 
    transform = (rds.GetGeoTransform())
    lX = transform[0]#左上角点
    lY = transform[3]
    rX = transform[1]#分辨率
    rY = transform[5]

    wpos=int((leftX-lX)/rX)
    hpos=int((upY-lY)/rY)

    width=int((rightX-leftX)/rX)
    height=int((bottomY-upY)/rY)
    BandsCount = rds.RasterCount
    arr = rds.ReadAsArray(wpos,hpos,width,height)
    fixX=list()
    nodatavalue=rds.GetRasterBand(1).GetNoDataValue()
    for i in range(height):
        if height>200:
            print(f"多边形裁剪进度:{round(((i+1)/height)*100,4)}%")
        Y=upY+i*rY+.00001
	    #射线算法只需要比对多边形的一条水平线上的边
        pointsindex=list()
        for k in range(len(polygonPoints)-1):
                point1=polygonPoints[k]
                point2=polygonPoints[k+1]
                if (point1.Y>=Y and point2.Y<=Y) or (point1.Y<=Y and point2.Y>=Y):
                    pointsindex.append(k)
        for j in range(width):
            count=0
            for m in (pointsindex):
                point1=polygonPoints[m]
                point2=polygonPoints[m+1]
                X=leftX+j*rX+.00001
                if point1.X==point2.X:
                    intersectX=point1.X
                    if intersectX>X:count+=1
                else:
                    k=(point2.Y-point1.Y)/(point2.X-point1.X)
                    if k==0:
                        if X<point1.X or X<point2.X:
                            count+=1
                    else:
                        intersectX=(Y-point1.Y)/k+point1.X
                        if intersectX>X:count+=1

            if count%2==0:
                if BandsCount>1:
                    for bc in range(BandsCount):
                        arr[bc][i][j]=(nodatavalue)
                else:
                    arr[i][j]=-1
    #为了测试结果的正确性,可以先将其写到硬盘
	#WriteRaster("test.tif",arr,inraster,width,height,BandsCount,leftX,upY)
    return arr,width,height,BandsCount,leftX,upY


四、创建分类器并训练样本

代码如下(示例):

def createClassifier(inraster,inshp,field:str="Id",treenum:int=100):
    rasterspatial = gdal.Open(inraster)
    spatial2=osr.SpatialReference()
    spatial2.ImportFromWkt(rasterspatial.GetProjectionRef())
    shpspatial=ogr.Open(inshp)
    layer=shpspatial.GetLayer(0)
    spatial1=layer.GetSpatialRef()
   
    ct=osr.CreateCoordinateTransformation(spatial1,spatial2)
    oFeature = layer.GetNextFeature()
    # 下面开始遍历图层中的要素
    geom=oFeature.GetGeometryRef()
    if geom.GetGeometryType()==ogr.wkbPoint:
        return createClassifierByPoint(inraster,inshp)
    k=geom.GetGeometryType()
    if geom.GetGeometryType()!=ogr.wkbPolygon:
        print("样本必须为单部件多边形")
        return False
    trainX = list()
    trainY = list()
    print("读取样本")
    while oFeature is not None:
        geom=oFeature.GetGeometryRef()
        wkt=geom.ExportToWkt()
        points=WKTToPoints(wkt)
        polygonPoints=[]
        value=oFeature.GetField(field)
        for point in points:
            pC=ct.TransformPoint(point.X,point.Y,0)
            polygonPoints.append(Point(pC[0],pC[1]))

        arr,width,height,BandsCount,leftX,upY=GetSubRaster(inraster,polygonPoints)
        for i in range(height):
            for k in range(width):
                nodata=True
                tem = list()
                for bc in range(BandsCount):
                    v=int(arr[bc][i][k])
                    tem.append(v)
                    if v>0:nodata=False
                if nodata:
                    continue
                trainX.append(tem)
                trainY.append(int(value))
        oFeature = layer.GetNextFeature()

    ct=None
    spatial1=None
    spatial2=None
    print("训练样本")
    clf = RandomForestClassifier(n_estimators=treenum)
    clf.fit(trainX, trainY)#训练样本
    print("训练完成")
    return clf


五、随机森林分类

from sklearn.ensemble import RandomForestClassifier
import osgeo
from osgeo import gdal
from osgeo import osr
from osgeo import ogr
import numpy
import os
import sys
import TrainByRandomForest as tbrf
def RandomForestClassification(ClassifyRaster,TrainRaster,TrainShp,outfile,blockSize=0,treenum=100,max_depth=10): 
    rds = gdal.Open(ClassifyRaster) 
    #print((rds.GetRasterBand(1).DataType))
    transform = (rds.GetGeoTransform())
    lX = transform[0]#左上角点
    lY = transform[3]
    rX = transform[1]#分辨率
    rY = transform[5]
    width = rds.RasterXSize
    height = rds.RasterYSize
    bX = lX + rX * width#右下角点
    bY = lY + rY * height
    BandsCount = rds.RasterCount
    clf = tbrf.createClassifier(TrainRaster,TrainShp)
    Z = list()
    fixX = list()
    if blockSize == 0:
        p,a = memory_usage()
        pv = 0.6 / 10000
        checkMemory(2000)#内存小于2GB,不在计算
        bl = (a - 2000) / pv / height / BandsCount
        blockSize = math.ceil(height / bl)
        if blockSize < 1:blockSize = 1
        if blockSize > 1:blockSize+=5
    if  blockSize != 1:
        blockHeight = 0
        modHeight = 0
        modHeight = height % blockSize

        if modHeight == 0:
            blockHeight = int(height / blockSize)
            
        else:
            blockHeight = int(height / blockSize)
        print(f"分块大小{width}*{blockHeight}")
        for bs in range(blockSize):
            print(f"计算块{bs+1}/{blockSize}")
            checkMemory(500)
            arr = rds.ReadAsArray(0,bs * blockHeight,width,blockHeight)
            
            for i in range(blockHeight):
                print(f"分块:{bs+1}/{blockSize}添加分类数据{round((i+1)*100/blockHeight,4)}%")
                for k in range(width):
                    tem = list()
                    for bc in range(BandsCount):
                        tem.append(int(arr[bc][i][k]))
                    fixX.append(tem)
            print(f"分块:{bs+1}/{blockSize}计算分类结果……")
            checkMemory(800)
            z = clf.predict(fixX)         
            Z.extend(z)
            fixX = list()
            arr = None
        print(f"计算余数:{width}*{modHeight}")
        checkMemory(500)
        arr = rds.ReadAsArray(0,blockSize * blockHeight,width,modHeight)
        if modHeight > 0:
            for i in range(modHeight):
                print(f"余块:添加分类数据{round((i+1)*100/modHeight,4)}%")
                for k in range(width):
                    tem = list()
                    for bc in range(BandsCount):
                        tem.append(int(arr[bc][i][k]))
                    fixX.append(tem)
            print("余块:计算分类结果……")
            checkMemory(500)
            z = clf.predict(fixX)         
            Z.extend(z)
        Z = numpy.array(Z)
        #Z=Z.reshape(1,width*height)
        Z = Z.reshape(height,width)
        fixX = None
        arr = None
    else:
        checkMemory(1000)
        arr = rds.ReadAsArray(0,0,width,height)
        for i in range(height):
            print(f"添加训练样本{round((i+1)*100/height,4)}%")
            for k in range(width):
                tem = list()
                for bc in range(BandsCount):
                    tem.append(int(arr[bc][i][k]))
                fixX.append(tem)
        arr = None
        print("计算分类结果……")
        Z = clf.predict(fixX)
        Z = numpy.array(Z)
        Z = Z.reshape(height,width)
    driver = gdal.GetDriverByName("GTiff")
    filepath,filename = os.path.split(outfile)
    short,ext = os.path.splitext(filename)
    print("创建输出文件")
    out = driver.Create(outfile,width,height,1,rds.GetRasterBand(1).DataType)
    out.SetGeoTransform(transform)
    out.SetProjection(rds.GetProjectionRef())
    print("写入数据……")
    out.GetRasterBand(1).WriteArray(Z)
    out.FlushCache()
    out = None
    print("计算完成")

六、分类结果

分类图像和分类结果
利用Python的Scikit-Learn库对遥感影像进行随机森林分类利用Python的Scikit-Learn库对遥感影像进行随机森林分类

关注开放GIS实验室与防灾减灾服务公众号,获取示例数据和全部代码
利用Python的Scikit-Learn库对遥感影像进行随机森林分类