# 遥感影像32位转8位(python)
# Remote-sensing imagery: convert 32-bit rasters to 8-bit (Python)
# Source site: 程序员文章站, published 2022-05-21 12:22:18
# (A truncated, line-collapsed page preview of the listing stood here; the full listing follows below.)
# -*- coding: utf-8 -*-
import os, sys, time
import numpy as np
from osgeo import ogr
from osgeo import gdal
from osgeo import gdal_array as ga
def stretch_n(bands, img_min, img_max, lower_percent=0, higher_percent=100):
    """Linearly stretch ``bands`` into the range [img_min, img_max].

    The values found at the ``lower_percent`` and ``higher_percent``
    percentiles (computed over the whole array, all bands together) are
    mapped to ``img_min`` and ``img_max``; everything outside that window
    is clipped to the range ends.

    :param bands: input data as a numpy array (any shape).
    :param img_min: minimum of the target range, e.g. 0 for 8-bit output.
    :param img_max: maximum of the target range, e.g. 255 for 8-bit output.
    :param lower_percent: percentile (0-100) that is mapped to ``img_min``.
    :param higher_percent: percentile (0-100) that is mapped to ``img_max``.
    :return: float32 array with the same shape as ``bands``.
    """
    data = np.asarray(bands)
    low = np.percentile(data, lower_percent)
    high = np.percentile(data, higher_percent)
    if high == low:
        # Degenerate window (e.g. a constant image): the linear map would be
        # 0/0 = NaN for values equal to the window. Values above it still go
        # to img_max and the rest to img_min, matching the clipped limits.
        return np.where(data > high, img_max, img_min).astype(np.float32)
    scaled = img_min + (data - low) * (img_max - img_min) / (high - low)
    return np.clip(scaled, img_min, img_max).astype(np.float32)
def read_img(filename):
    """Read the full extent of a raster file with GDAL.

    :param filename: path of the raster to open.
    :return: tuple ``(im_proj, im_geotrans, im_width, im_height, im_data)``:
        the projection string, the geotransform, the raster size in pixels,
        and the array from ``Dataset.ReadAsArray`` over the whole raster
        (presumably 2-D for one band and (bands, rows, cols) otherwise,
        which is the layout ``write_img`` expects -- confirm for your GDAL).
    :raises IOError: if GDAL cannot open ``filename``.
    """
    dataset = gdal.Open(filename)
    if dataset is None:
        # Without gdal.UseExceptions(), Open() reports failure by returning
        # None; fail here with a clear message instead of an AttributeError.
        raise IOError("Unable to open raster: %s" % filename)
    im_width = dataset.RasterXSize
    im_height = dataset.RasterYSize
    im_geotrans = dataset.GetGeoTransform()
    im_proj = dataset.GetProjection()
    im_data = dataset.ReadAsArray(0, 0, im_width, im_height)
    # Release the handle so GDAL closes the file.
    del dataset
    return im_proj, im_geotrans, im_width, im_height, im_data
def write_img(filename, im_proj, im_geotrans, im_data):
    """Write a numpy array to a GeoTIFF file.

    :param filename: output path handed to the GTiff driver's Create().
    :param im_proj: projection string passed to SetProjection().
    :param im_geotrans: geotransform passed to SetGeoTransform().
    :param im_data: 2-D ``(rows, cols)`` array for one band, or 3-D
        ``(bands, rows, cols)`` array for several bands.
    :raises IOError: if the output raster cannot be created.
    """
    dtype_name = im_data.dtype.name
    # Compare dtype names exactly: a substring test such as 'int8' in name
    # also matches the unsigned variant, so signed data would be written to
    # an unsigned band and negative values corrupted.
    if dtype_name == 'uint8':
        datatype = gdal.GDT_Byte
    elif dtype_name == 'int8':
        # GDT_Int8 exists only in GDAL >= 3.7; Int16 holds every int8 value.
        datatype = getattr(gdal, 'GDT_Int8', gdal.GDT_Int16)
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int16':
        datatype = gdal.GDT_Int16
    else:
        datatype = gdal.GDT_Float32

    if im_data.ndim == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands = 1
        im_height, im_width = im_data.shape

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(filename, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise IOError("Unable to create raster: %s" % filename)
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    if im_data.ndim == 2:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        # Index per band even when there is only one, so a (1, rows, cols)
        # array is written as 2-D slices rather than passed whole.
        for i in range(im_bands):
            dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    dataset.FlushCache()
    # Release the handle so GDAL closes the file.
    del dataset
if __name__ == "__main__":
    origin_32 = ['train_image1.tif', 'train_image2.tif']
    target_path = 'target_8'
    # `mkdir` was never defined; create the output directory with os.
    if not os.path.isdir(target_path):
        os.makedirs(target_path)
    # Use only the file name: os.path.join() with an absolute input path
    # would return that path itself and overwrite the source image.
    output_8 = [os.path.join(target_path, os.path.basename(sub_path))
                for sub_path in origin_32]
    for src_path, dst_path in zip(origin_32, output_8):
        im_proj, im_geotrans, _, _, im_data = read_img(src_path)
        out = stretch_n(im_data, 0, 255)
        # Stretched values lie in [0, 255]; int8 only reaches 127, so cast
        # to unsigned 8-bit to avoid overflow.
        write_img(dst_path, im_proj, im_geotrans, out.astype(np.uint8))
# Original article (本文地址): https://blog.csdn.net/weixin_42990464/article/details/107152799