python多线程方法详解
程序员文章站
2022-03-08 18:31:15
处理多个数据和多文件时,使用for循环的速度非常慢,此时需要用多线程来加速运行进度,常用的模块为multiprocess和joblib,下面对两种包我常用的方法进行说明。1、模块安装pip insta...
处理多个数据和多文件时,使用for循环的速度非常慢,此时可以用并行计算来加速运行。常用的模块为multiprocessing和joblib(二者默认都是基于多进程而非多线程实现并行,因此适合计算密集型任务),下面对这两个包我常用的方法进行说明。
1、模块安装
pip install joblib
(multiprocessing是Python标准库自带的模块,无需也不应通过pip安装。)
2、以分块计算ndvi为例
首先导入需要的包
import time
from multiprocessing import Pool
from multiprocessing import cpu_count
from multiprocessing import pool

import numpy as np
from joblib import Parallel
from joblib import parallel, delayed
from osgeo import gdal
定义gdalutil类,以读取遥感数据
class gdalutil:
    """Static helpers for reading and writing raster files through GDAL.

    Every method is a ``@staticmethod``; the class is only used as a
    namespace and holds no state.
    """

    def __init__(self):
        pass

    @staticmethod
    def _open_dataset(raster_path):
        """Open *raster_path* read-only and return the GDAL dataset.

        Raises:
            IOError: if GDAL returns no dataset for the path.
        """
        # Register the raster drivers and let GDAL treat file names as UTF-8.
        gdal.AllRegister()
        gdal.SetConfigOption('GDAL_FILENAME_IS_UTF8', 'YES')
        dataset = gdal.Open(raster_path, gdal.GA_ReadOnly)
        if dataset is None:
            # Fail here with a readable message instead of crashing later
            # with an AttributeError on None.
            raise IOError('打开图像{0} 失败.'.format(raster_path))
        return dataset

    @staticmethod
    def read_file(raster_file, read_band=None):
        """Read a whole raster into an array.

        Args:
            raster_file: path of the raster to read.
            read_band: band number to read (passed to ``GetRasterBand``);
                ``None`` reads the full dataset.

        Returns:
            The array returned by GDAL's ``ReadAsArray``.
        """
        dataset = gdalutil._open_dataset(raster_file)
        raster_width = dataset.RasterXSize   # columns
        raster_height = dataset.RasterYSize  # rows
        if read_band is None:
            return dataset.ReadAsArray(0, 0, raster_width, raster_height)
        band = dataset.GetRasterBand(read_band)
        return band.ReadAsArray(0, 0, raster_width, raster_height)

    @staticmethod
    def read_block_data(dataset, band_num, cols_read, rows_read,
                        start_col=0, start_row=0):
        """Read a rectangular window of one band from an open dataset.

        Args:
            dataset: an already opened GDAL dataset.
            band_num: band number passed to ``GetRasterBand``.
            cols_read: window width in pixels.
            rows_read: window height in pixels.
            start_col: column offset of the window's left edge.
            start_row: row offset of the window's top edge.

        Returns:
            The array returned by ``ReadAsArray`` for that window.
        """
        band = dataset.GetRasterBand(band_num)
        return band.ReadAsArray(start_col, start_row, cols_read, rows_read)

    @staticmethod
    def get_raster_band(raster_path):
        """Return the number of bands in the raster at *raster_path*."""
        return gdalutil._open_dataset(raster_path).RasterCount

    @staticmethod
    def get_file_size(raster_path):
        """Return ``(width, height)`` of the raster, i.e. (columns, rows)."""
        dataset = gdalutil._open_dataset(raster_path)
        return dataset.RasterXSize, dataset.RasterYSize

    @staticmethod
    def get_file_geotransform(raster_path):
        """Return the affine geotransform of the raster at *raster_path*."""
        return gdalutil._open_dataset(raster_path).GetGeoTransform()

    @staticmethod
    def get_file_proj(raster_path):
        """Return the projection (spatial reference) of the raster."""
        return gdalutil._open_dataset(raster_path).GetProjection()

    @staticmethod
    def write_file(dataset, geotransform, project, output_path,
                   out_format='gtiff', etype=gdal.GDT_Float32):
        """Write an array to a raster file.

        Args:
            dataset: array to write; a 3-D array is unpacked as
                (bands, rows, cols), anything else as (rows, cols).
            geotransform: geotransform to store in the output.
            project: projection to store in the output.
            output_path: destination file path.
            out_format: GDAL driver name.
            etype: GDAL pixel data type of the output.

        Returns:
            None. If the driver is not available a message is printed and
            nothing is written.

        Raises:
            IOError: if the driver fails to create the output file.
        """
        if np.ndim(dataset) == 3:
            out_band, out_rows, out_cols = dataset.shape
        else:
            out_band = 1
            out_rows, out_cols = dataset.shape

        # Look up the driver for the requested output format.
        out_driver = gdal.GetDriverByName(out_format)
        if out_driver is None:
            print('格式%s 不支持creat()方法.\n' % out_format)
            return

        out_dataset = out_driver.Create(output_path, xsize=out_cols,
                                        ysize=out_rows, bands=out_band,
                                        eType=etype)
        if out_dataset is None:
            raise IOError('failed to create {0}'.format(output_path))

        out_dataset.SetGeoTransform(geotransform)
        out_dataset.SetProjection(project)

        if out_band == 1:
            out_dataset.GetRasterBand(1).WriteArray(dataset)
        else:
            for i in range(out_band):
                # GetRasterBand is called with i + 1 for array plane i.
                out_dataset.GetRasterBand(i + 1).WriteArray(dataset[i])
        # Dropping the last reference lets GDAL flush and close the file.
        del out_dataset
定义计算ndvi的函数
def cal_ndvi(multi):
    """Compute NDVI for one rectangular block of a raster.

    Band 4 is read as near-infrared and band 3 as red.
    NOTE(review): this matches the GF (Gaofen) layout the article targets;
    confirm the band order for other sensors.

    Args:
        multi: sequence of
            ``[raster_path, start_col, start_row, cols_step, rows_step]``
            (column offset comes before row offset).

    Returns:
        A float64 array of NDVI values for the block. Pixels where
        ``nir + red`` is 0, and values outside ``[-1, 1.5]``, are set to 0.

    Raises:
        IOError: if the raster cannot be opened.
    """
    input_file, start_col, start_row, cols_step, rows_step = multi
    dataset = gdal.Open(input_file, gdal.GA_ReadOnly)
    if dataset is None:
        raise IOError('failed to open {0}'.format(input_file))

    # Convert to float before any arithmetic: with unsigned integer band
    # data ``nir - red`` would wrap around instead of going negative.
    nir_data = gdalutil.read_block_data(
        dataset, 4, cols_step, rows_step,
        start_col=start_col, start_row=start_row).astype('float64')
    red_data = gdalutil.read_block_data(
        dataset, 3, cols_step, rows_step,
        start_col=start_col, start_row=start_row).astype('float64')
    # Release the file handle as soon as the pixels are in memory.
    dataset = None

    band_sum = nir_data + red_data
    ndvi = np.zeros_like(band_sum)
    # Divide only where the denominator is non-zero; the remaining pixels
    # keep the 0 they were initialised with (no NaN, no runtime warning).
    np.divide(nir_data - red_data, band_sum, out=ndvi, where=band_sum != 0)
    ndvi[(ndvi > 1.5) | (ndvi < -1)] = 0
    return ndvi
定义主函数
# Script entry point: split the raster into blocks, compute NDVI for each
# block in parallel, reassemble the blocks and write the result to disk.
# The guard is required for multiprocessing on platforms that spawn workers.
if __name__ == "__main__":
    input_file = r'd:\originaldata\gf1\namucuo2021.tif'
    output_file = r'd:\originaldata\gf1\namucuo2021_ndvi.tif'
    method = 'joblib'
    # method = 'multiprocessing'

    # Basic information about the input raster.
    raster_cols, raster_rows = gdalutil.get_file_size(input_file)
    geotransform = gdalutil.get_file_geotransform(input_file)
    project = gdalutil.get_file_proj(input_file)

    # Block size in pixels.
    rows_block_size = 50
    cols_block_size = 50

    # One task per block: [path, start_col, start_row, cols_step, rows_step].
    multi = []
    for j in range(0, raster_rows, rows_block_size):
        # The last block in each direction is clipped to the raster edge.
        rows_step = min(rows_block_size, raster_rows - j)
        for i in range(0, raster_cols, cols_block_size):
            cols_step = min(cols_block_size, raster_cols - i)
            multi.append([input_file, i, j, cols_step, rows_step])

    t1 = time.time()
    if method == 'multiprocessing':
        # Leave one core free, but never request fewer than one worker
        # (cpu_count() - 1 is 0 on a single-core machine).
        with Pool(processes=max(1, cpu_count() - 1)) as worker_pool:
            # map takes an iterable of arguments and returns a list of
            # results in the same order.
            res = worker_pool.map(cal_ndvi, multi)
    else:
        res = Parallel(n_jobs=-1)(
            delayed(cal_ndvi)(input_list) for input_list in multi)
    t2 = time.time()
    print("total time:" + str(t2 - t1))

    # Put every block result back at its position in the full-size array.
    out_data = np.zeros([raster_rows, raster_cols], dtype='float')
    for result, (_, start_col, start_row, cols_step, rows_step) in zip(res, multi):
        out_data[start_row:start_row + rows_step,
                 start_col:start_col + cols_step] = result

    gdalutil.write_file(out_data, geotransform, project, output_file)
双重for循环时,两层for循环都使用multiprocessing会报错(multiprocessing.Pool的工作进程是守护进程,守护进程不允许再创建子进程)。这时可以外层for循环使用joblib方法,内层for循环改为multiprocessing方法,不会报错。
到此这篇关于python多线程方法详解的文章就介绍到这了,更多相关python多线程内容请搜索以前的文章或继续浏览下面的相关文章希望大家以后多多支持!
上一篇: 带你了解Java数据结构和算法之二叉树
下一篇: Java的nanoTime()